Бонусный урок Spring AI

This commit is contained in:
gren-d-EYDenisova
2026-04-30 21:25:13 +03:00
parent ca4fd7994a
commit ac8d170b13
20 changed files with 488 additions and 28 deletions
+3
View File
@@ -0,0 +1,3 @@
/.idea/
/target/
/src/main/resources/env.properties
+89
View File
@@ -0,0 +1,89 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.4.3</version>
<relativePath/>
</parent>
<groupId>com.example</groupId>
<artifactId>agile-story-generator</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>agile-story-generator</name>
<description>User Story generator</description>
<properties>
<java.version>21</java.version>
<spring-ai.version>1.0.0-M6</spring-ai.version>
</properties>
<repositories>
<repository>
<id>spring-milestones</id>
<name>Spring Milestones</name>
<url>https://repo.spring.io/milestone</url>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jpa</artifactId>
</dependency>
<!-- PostgreSQL -->
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.hibernate.orm</groupId>
<artifactId>hibernate-vector</artifactId>
<version>6.6.8.Final</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
@@ -0,0 +1,11 @@
package com.example.agile;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
public class AgileApplication {
public static void main(String[] args) {
SpringApplication.run(AgileApplication.class, args);
}
}
@@ -0,0 +1,33 @@
package com.example.agile.controllers;
import com.example.agile.controllers.dto.ChatRequest;
import com.example.agile.controllers.dto.TaskDto;
import com.example.agile.services.ChatService;
import com.example.agile.services.GenerationService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/api/tasks")
public class TaskController {
@Autowired
private ChatService chatService;
@Autowired
private GenerationService generationService;
@PostMapping("/message")
public ResponseEntity<String> message(@RequestBody ChatRequest request) {
return ResponseEntity.ok(chatService.sendMessage(request.message()));
}
@PostMapping
public ResponseEntity<TaskDto> generate(@RequestBody ChatRequest request) {
return ResponseEntity.ok(TaskDto.toDto(generationService.generateOrGetTask(request.message())));
}
}
@@ -0,0 +1,4 @@
package com.example.agile.controllers.dto;
public record ChatRequest(String message) {
}
@@ -0,0 +1,14 @@
package com.example.agile.controllers.dto;
import com.example.agile.entities.Task;
public record TaskDto(String requirement,
String userStory,
String acceptanceCriteria,
Integer complexity) {
public static TaskDto toDto(Task task) {
return new TaskDto(task.getRequirement(), task.getUserStory(), task.getAcceptanceCriteria(), task.getComplexity());
}
}
@@ -0,0 +1,40 @@
package com.example.agile.entities;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Table;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.type.SqlTypes;
import java.util.UUID;
@Entity
@Table(name = "tasks")
@NoArgsConstructor
@Data
@AllArgsConstructor
public class Task {
@Id
private UUID id;
@Column(nullable = false, length = 2000)
private String requirement;
@Column(nullable = false, length = 500)
private String userStory;
@Column(nullable = false, length = 2000)
private String acceptanceCriteria;
@Column(nullable = false)
private Integer complexity;
@JdbcTypeCode(SqlTypes.VECTOR)
@Column(columnDefinition = "vector(768)")
private float[] embedding;
}
@@ -0,0 +1,26 @@
package com.example.agile.repositories;
import com.example.agile.entities.Task;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.UUID;
@Repository
public interface TaskRepository extends JpaRepository<Task, UUID> {
@Query("SELECT DISTINCT t FROM Task t WHERE t.complexity = :complexity")
List<Task> findByComplexity(@Param("complexity") Integer complexity);
@Query(value = """
SELECT * FROM tasks
ORDER BY embedding <=> cast(:embedding as vector)
LIMIT :limit
""", nativeQuery = true)
List<Task> findTopKSimilar(@Param("embedding") String embeddingString,
@Param("limit") int limit);
}
@@ -0,0 +1,30 @@
package com.example.agile.services;
import com.example.agile.tools.TaskTool;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;
@Service
public class ChatService {
private final ChatClient chatClient;
public ChatService(OllamaChatModel chatModel, TaskTool taskTool) {
chatClient = ChatClient.builder(chatModel)
.defaultSystem("You are an Agile expert." +
"You have access to the task database (User Task)." +
"When a user asks about tasks" +
"use the available functions to obtain up-to-date information.")
.defaultTools(taskTool)
.build();
}
public String sendMessage(String message) {
return chatClient.prompt()
.user(message)
.call()
.content();
}
}
@@ -0,0 +1,63 @@
package com.example.agile.services;
import com.example.agile.entities.Task;
import com.example.agile.repositories.TaskRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@RequiredArgsConstructor
public class EmbeddingSimilarityService {
@Value("${story.similarity.threshold:0.85}")
private double threshold;
@Value("${story.similarity.top-k:5}")
private int topK;
private final TaskRepository taskRepository;
public double cosineSimilarityFromDistance(double cosineDistance) {
return 1.0 - cosineDistance;
}
public Task findSimilarTask(float[] newEmbedding) {
String embeddingStr = floatArrayToPgVectorString(newEmbedding);
List<Task> candidates = taskRepository.findTopKSimilar(embeddingStr, topK);
for (Task candidate : candidates) {
double distance = computeCosineDistance(newEmbedding, candidate.getEmbedding());
double similarity = 1.0 - distance;
if (similarity >= threshold) {
return candidate;
}
}
return null;
}
private double computeCosineDistance(float[] a, float[] b) {
if (a.length != b.length) return 1.0;
double dot = 0.0, normA = 0.0, normB = 0.0;
for (int i = 0; i < a.length; i++) {
dot += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
if (normA == 0.0 || normB == 0.0) return 1.0;
double cosineSimilarity = dot / (Math.sqrt(normA) * Math.sqrt(normB));
return 1.0 - cosineSimilarity;
}
private String floatArrayToPgVectorString(float[] arr) {
StringBuilder sb = new StringBuilder("[");
for (int i = 0; i < arr.length; i++) {
if (i > 0) sb.append(",");
sb.append(arr[i]);
}
sb.append("]");
return sb.toString();
}
}
@@ -0,0 +1,64 @@
package com.example.agile.services;
import com.example.agile.entities.Task;
import com.example.agile.repositories.TaskRepository;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.UUID;
@RequiredArgsConstructor
@Service
public class GenerationService {
private final ChatModel chatModel;
private final EmbeddingModel embeddingModel;
private final TaskRepository taskRepository;
private final EmbeddingSimilarityService similarityService;
private final ObjectMapper objectMapper = new ObjectMapper();
private static final String PROMPT_TEMPLATE = """
You are an Agile expert. Generate a User Story, acceptance criteria, and complexity estimate for the following requirement:
"%s"
The response should be in JSON format without explanation:
{
"userStory": "As a <role>, I want <action> so that <goal>",
"acceptanceCriteria": "- item 1\\n- item 2\\n- item 3",
"complexity": a number from 1 to 5
}
""";
@Transactional
public Task generateOrGetTask(String requirement) {
float[] requirementEmbedding = embeddingModel.embed(requirement);
Task existingTask = similarityService.findSimilarTask(requirementEmbedding);
if (existingTask != null) {
return existingTask;
}
String prompt = String.format(PROMPT_TEMPLATE, requirement);
String aiResponse = chatModel.call(prompt);
Task newTask = parseTask(requirement, aiResponse, requirementEmbedding);
return taskRepository.save(newTask);
}
private Task parseTask(String requirement, String aiResponse, float[] requirementEmbedding) {
try {
JsonNode root = objectMapper.readTree(aiResponse);
String userStory = root.get("userStory").asText();
String acceptanceCriteria = root.get("acceptanceCriteria").asText();
int complexity = root.get("complexity").asInt();
return new Task(UUID.randomUUID(), requirement, userStory, acceptanceCriteria, complexity, requirementEmbedding);
} catch (Exception e) {
throw new RuntimeException("Ошибка парсинга ответа модели: " + aiResponse, e);
}
}
}
@@ -0,0 +1,26 @@
package com.example.agile.services;
import com.example.agile.controllers.dto.TaskDto;
import com.example.agile.repositories.TaskRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
@Service
@RequiredArgsConstructor
public class TaskService {
private final TaskRepository repository;
@Transactional(readOnly = true)
public List<TaskDto> findByComplexity(Integer complexity) {
return repository.findByComplexity(complexity).stream().map(TaskDto::toDto).toList();
}
@Transactional(readOnly = true)
public List<TaskDto> findAll() {
return repository.findAll().stream().map(TaskDto::toDto).toList();
}
}
@@ -0,0 +1,27 @@
package com.example.agile.tools;
import com.example.agile.controllers.dto.TaskDto;
import com.example.agile.services.TaskService;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.stereotype.Component;
import java.util.List;
@Component
@RequiredArgsConstructor
public class TaskTool {
private final TaskService taskService;
@Tool(name = "getAllTask", description = "Get all tasks (User Stories) from the database")
public List<TaskDto> getAllTask() {
return taskService.findAll();
}
@Tool(name = "getTasksByComplexity", description = "Get all tasks by difficulty")
public List<TaskDto> getTasksByComplexity(
@ToolParam(description = "Task difficulty level (1 to 5)") Integer complexity) {
return taskService.findByComplexity(complexity);
}
}
@@ -0,0 +1,38 @@
spring:
application:
name: agile-story-generator
datasource:
url: jdbc:postgresql://localhost:5437/agiledb?reWriteBatchedInserts=true&currentSchema=public
username: agile_user
password: agile_password
driver-class-name: org.postgresql.Driver
jpa:
database-platform: org.hibernate.dialect.PostgreSQLDialect
hibernate:
ddl-auto: update
show-sql: true
properties:
hibernate:
types:
print:
banner: false
ai:
ollama:
base-url: http://localhost:11434
chat:
model: mistral:latest
options:
temperature: 0.3
tool-choice: auto
embedding:
model: nomic-embed-text:latest
story:
similarity:
threshold: 0.85
logging:
level:
org.springframework.ai: DEBUG
@@ -0,0 +1,10 @@
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS tasks (
id BIGSERIAL PRIMARY KEY,
requirement TEXT NOT NULL,
user_story VARCHAR(500) NOT NULL,
acceptance_criteria TEXT NOT NULL,
complexity INTEGER NOT NULL,
embedding vector(768) NOT NULL
);
@@ -1,6 +1,6 @@
spring:
shell:
interactive:
enabled: false
enabled: true
main:
allow-circular-references: true
@@ -5,12 +5,10 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.bean.override.mockito.MockitoBean;
import static org.junit.jupiter.api.Assertions.*;
//@SpringBootTest(classes = {Service1.class, Service2.class})
//@TestPropertySource("classpath:test.properties")
//@DirtiesContext(classMode = DirtiesContext.ClassMode.BEFORE_CLASS)
@TestPropertySource("classpath:test.properties")
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD)
@SpringBootTest
class IntegrationTest1 {
@@ -32,6 +30,7 @@ class IntegrationTest1 {
}
@Test
@DirtiesContext(methodMode = DirtiesContext.MethodMode.AFTER_METHOD)
void test2() {
System.out.println(service1.getName() + ": " + service1.getState());
System.out.println(service2.getName() + ": " + service2.getState());
@@ -9,6 +9,7 @@ import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.context.ContextConfiguration;
import ru.otus.example.testconfigurationdemo.family.FamilyMember;
import ru.otus.example.testconfigurationdemo.family.parents.Mother;
import ru.otus.example.testconfigurationdemo.family.pets.Dog;
import java.util.Map;
@@ -16,21 +17,11 @@ import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
@DisplayName("В NestedConfigurationDemoTest семья должна ")
@SpringBootTest
//@SpringBootTest(classes = Dog.class)
//@ContextConfiguration(classes = Dog.class)
@SpringBootTest(classes = {
Dog.class, Mother.class
})
public class NestedConfigurationDemoTest {
@ComponentScan("ru.otus.example.testconfigurationdemo.family.pets")
@Configuration
static class NestedConfiguration {
/*
@Bean
FamilyMember dog() {
return new Dog();
}
*/
}
@Autowired
private Map<String, FamilyMember> family;
@@ -16,9 +16,7 @@ import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
@DisplayName("В NestedTestConfigurationDemoTest семья должна ")
@SpringBootTest
//@SpringBootTest(properties = "spring.main.allow-bean-definition-overriding=true")
//@TestPropertySource(properties = "spring.main.allow-bean-definition-overriding=true")
@SpringBootTest(properties = "spring.main.allow-bean-definition-overriding=true")
public class NestedTestConfigurationDemoTest {
@TestConfiguration
@@ -28,7 +26,6 @@ public class NestedTestConfigurationDemoTest {
return new Father();
}
/*
@Bean
FamilyMember dog() {
return new Dog() {
@@ -38,7 +35,6 @@ public class NestedTestConfigurationDemoTest {
}
};
}
*/
}
@Autowired
@@ -8,13 +8,9 @@ import ru.otus.example.testconfigurationdemo.family.FamilyMember;
import ru.otus.example.testconfigurationdemo.family.parents.Father;
import ru.otus.example.testconfigurationdemo.family.pets.Dog;
@ComponentScan({"ru.otus.example.testconfigurationdemo.family.parents",
"ru.otus.example.testconfigurationdemo.family.childrens"})
/*
@ComponentScan(value = "ru.otus.example.testconfigurationdemo.family",
excludeFilters = @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, classes = Dog.class))
*/
@SpringBootConfiguration
public class TestSpringBootConfiguration {
@Bean