merge main

This commit is contained in:
Yansong Zhang
2025-12-19 11:50:45 +08:00
3415 changed files with 323329 additions and 49653 deletions

View File

@@ -0,0 +1,205 @@
# Test Generation Checklist
Use this checklist when generating or reviewing tests for Dify frontend components.
## Pre-Generation
- [ ] Read the component source code completely
- [ ] Identify component type (component, hook, utility, page)
- [ ] Run `pnpm analyze-component <path>` if available
- [ ] Note complexity score and features detected
- [ ] Check for existing tests in the same directory
- [ ] **Identify ALL files in the directory** that need testing (not just index)
## Testing Strategy
### ⚠️ Incremental Workflow (CRITICAL for Multi-File)
- [ ] **NEVER generate all tests at once** - process one file at a time
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
- [ ] Create a todo list to track progress before starting
- [ ] For EACH file: write → run test → verify pass → then next
- [ ] **DO NOT proceed** to next file until current one passes
### Path-Level Coverage
- [ ] **Test ALL files** in the assigned directory/path
- [ ] List all components, hooks, utilities that need coverage
- [ ] Decide: single spec file (integration) or multiple spec files (unit)
### Complexity Assessment
- [ ] Run `pnpm analyze-component <path>` for complexity score
- [ ] **Complexity > 50**: Consider refactoring before testing
- [ ] **500+ lines**: Consider splitting before testing
- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure
### Integration vs Mocking
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
## Required Test Sections
### All Components MUST Have
- [ ] **Rendering tests** - Component renders without crashing
- [ ] **Props tests** - Required props, optional props, default values
- [ ] **Edge cases** - null, undefined, empty values, boundaries
### Conditional Sections (Add When Feature Present)
| Feature | Add Tests For |
|---------|---------------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
## Code Quality Checklist
### Structure
- [ ] Uses `describe` blocks to group related tests
- [ ] Test names follow `should <behavior> when <condition>` pattern
- [ ] AAA pattern (Arrange-Act-Assert) is clear
- [ ] Comments explain complex test scenarios
### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
### Queries
- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
- [ ] Use `queryBy*` for absence assertions
- [ ] Use `findBy*` for async elements
- [ ] `getByTestId` only as last resort
### Async
- [ ] All async tests use `async/await`
- [ ] `waitFor` wraps async assertions
- [ ] Fake timers properly setup/teardown
- [ ] No floating promises
### TypeScript
- [ ] No `any` types without justification
- [ ] Mock data uses actual types from source
- [ ] Factory functions have proper return types
## Coverage Goals (Per File)
For the current file being tested:
- [ ] 100% function coverage
- [ ] 100% statement coverage
- [ ] >95% branch coverage
- [ ] >95% line coverage
## Post-Generation (Per File)
**Run these checks after EACH test file, not just at the end:**
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file
### After All Files Complete
- [ ] Run full directory test: `pnpm test -- path/to/directory/`
- [ ] Check coverage report: `pnpm test -- --coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`
## Common Issues to Watch
### False Positives
```typescript
// ❌ Mock doesn't match actual behavior
jest.mock('./Component', () => () => <div>Mocked</div>)
// ✅ Mock matches actual conditional logic
jest.mock('./Component', () => ({ isOpen }: any) =>
isOpen ? <div>Content</div> : null
)
```
### State Leakage
```typescript
// ❌ Shared state not reset
let mockState = false
jest.mock('./useHook', () => () => mockState)
// ✅ Reset in beforeEach
beforeEach(() => {
mockState = false
})
```
### Async Race Conditions
```typescript
// ❌ Not awaited
it('loads data', () => {
render(<Component />)
expect(screen.getByText('Data')).toBeInTheDocument()
})
// ✅ Properly awaited
it('loads data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### Missing Edge Cases
Always test these scenarios:
- `null` / `undefined` inputs
- Empty strings / arrays / objects
- Boundary values (0, -1, MAX_INT)
- Error states
- Loading states
- Disabled states
## Quick Commands
```bash
# Run specific test
pnpm test -- path/to/file.spec.tsx
# Run with coverage
pnpm test -- --coverage path/to/file.spec.tsx
# Watch mode
pnpm test -- --watch path/to/file.spec.tsx
# Update snapshots (use sparingly)
pnpm test -- -u path/to/file.spec.tsx
# Analyze component
pnpm analyze-component path/to/component.tsx
# Review existing test
pnpm analyze-component path/to/component.tsx --review
```

View File

@@ -0,0 +1,321 @@
---
name: Dify Frontend Testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
---
# Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.
## When to Apply This Skill
Apply this skill when the user:
- Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness
- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
- Wants to understand **testing patterns** in the Dify codebase
**Do NOT apply** when:
- User is asking about backend/API tests (Python/pytest)
- User is asking about E2E tests (Playwright/Cypress)
- User is only asking conceptual questions without code context
## Quick Reference
### Tech Stack
| Tool | Version | Purpose |
|------|---------|---------|
| Jest | 29.7 | Test runner |
| React Testing Library | 16.0 | Component testing |
| happy-dom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |
### Key Commands
```bash
# Run all tests
pnpm test
# Watch mode
pnpm test -- --watch
# Run specific file
pnpm test -- path/to/file.spec.tsx
# Generate coverage report
pnpm test -- --coverage
# Analyze component complexity
pnpm analyze-component <path>
# Review existing test
pnpm analyze-component <path> --review
```
### File Naming
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Integration tests: `web/__tests__/` directory
## Test Structure Template
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import Component from './index'
// ✅ Import real project components (DO NOT mock these)
// import Loading from '@/app/components/base/loading'
// import { ChildComponent } from './child-component'
// ✅ Mock external dependencies only
jest.mock('@/service/api')
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
usePathname: () => '/test',
}))
// Shared state for mocks (if needed)
let mockSharedState = false
describe('ComponentName', () => {
beforeEach(() => {
jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
mockSharedState = false // ✅ Reset shared state
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange
const props = { title: 'Test' }
// Act
render(<Component {...props} />)
// Assert
expect(screen.getByText('Test')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should apply custom className', () => {
render(<Component className="custom" />)
expect(screen.getByRole('button')).toHaveClass('custom')
})
})
// User Interactions
describe('User Interactions', () => {
it('should handle click events', () => {
const handleClick = jest.fn()
render(<Component onClick={handleClick} />)
fireEvent.click(screen.getByRole('button'))
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle null data', () => {
render(<Component data={null} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
render(<Component items={[]} />)
expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
})
})
```
## Testing Workflow (CRITICAL)
### ⚠️ Incremental Approach Required
**NEVER generate all test files at once.** For complex components or multi-file directories:
1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
1. **Verify before proceeding**: Do NOT continue to next file until current passes
```
For each file:
┌────────────────────────────────────────┐
│ 1. Write test │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 3. PASS? → Mark complete, next file │
│ FAIL? → Fix first, then continue │
└────────────────────────────────────────┘
```
### Complexity-Based Order
Process in this order for multi-file testing:
1. 🟢 Utility functions (simplest)
1. 🟢 Custom hooks
1. 🟡 Simple components (presentational)
1. 🟡 Medium components (state, effects)
1. 🔴 Complex components (API, routing)
1. 🔴 Integration tests (index files - last)
### When to Refactor First
- **Complexity > 50**: Break into smaller pieces before testing
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first
> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
## Testing Strategy
### Path-Level Testing (Directory Testing)
When assigned to test a directory/path, test **ALL content** within that path:
- Test all components, hooks, utilities in the directory (not just `index` file)
- Use incremental approach: one file at a time, verify each before proceeding
- Goal: 100% coverage of ALL files in the directory
### Integration Testing First
**Prefer integration testing** when writing tests for a directory:
-**Import real project components** directly (including base components and siblings)
-**Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
-**DO NOT mock** base components (`@/app/components/base/*`)
-**DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
## Core Principles
### 1. AAA Pattern (Arrange-Act-Assert)
Every test should clearly separate:
- **Arrange**: Setup test data and render component
- **Act**: Perform user actions
- **Assert**: Verify expected outcomes
### 2. Black-Box Testing
- Test observable behavior, not implementation details
- Use semantic queries (getByRole, getByLabelText)
- Avoid testing internal state directly
- **Prefer pattern matching over hardcoded strings** in assertions:
```typescript
// ❌ Avoid: hardcoded text assertions
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Better: role-based queries
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Better: pattern matching
expect(screen.getByText(/loading/i)).toBeInTheDocument()
```
### 3. Single Behavior Per Test
Each test verifies ONE user-observable behavior:
```typescript
// ✅ Good: One behavior
it('should disable button when loading', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
})
// ❌ Bad: Multiple behaviors
it('should handle loading state', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByText('Loading...')).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveClass('loading')
})
```
### 4. Semantic Naming
Use `should <behavior> when <condition>`:
```typescript
it('should show error message when validation fails')
it('should call onSubmit when form is valid')
it('should disable input when isReadOnly is true')
```
## Required Test Scenarios
### Always Required (All Components)
1. **Rendering**: Component renders without crashing
1. **Props**: Required props, optional props, default values
1. **Edge Cases**: null, undefined, empty values, boundary conditions
### Conditional (When Present)
| Feature | Test Focus |
|---------|-----------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | All onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
## Coverage Goals (Per File)
For each test file generated, aim for:
-**100%** function coverage
-**100%** statement coverage
-**>95%** branch coverage
-**>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
## Detailed Guides
For more detailed information, refer to:
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `guides/mocking.md` - Mock patterns and best practices
- `guides/async-testing.md` - Async operations and API calls
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
- `guides/common-patterns.md` - Frequently used testing patterns
## Authoritative References
### Primary Specification (MUST follow)
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
- `web/app/components/base/button/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
- `web/jest.config.ts` - Jest configuration
- `web/jest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool
- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)

View File

@@ -0,0 +1,345 @@
# Async Testing Guide
## Core Async Patterns
### 1. waitFor - Wait for Condition
```typescript
import { render, screen, waitFor } from '@testing-library/react'
it('should load and display data', async () => {
render(<DataComponent />)
// Wait for element to appear
await waitFor(() => {
expect(screen.getByText('Loaded Data')).toBeInTheDocument()
})
})
it('should hide loading spinner after load', async () => {
render(<DataComponent />)
// Wait for element to disappear
await waitFor(() => {
expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
})
})
```
### 2. findBy\* - Async Queries
```typescript
it('should show user name after fetch', async () => {
render(<UserProfile />)
// findBy returns a promise, auto-waits up to 1000ms
const userName = await screen.findByText('John Doe')
expect(userName).toBeInTheDocument()
// findByRole with options
const button = await screen.findByRole('button', { name: /submit/i })
expect(button).toBeEnabled()
})
```
### 3. userEvent for Async Interactions
```typescript
import userEvent from '@testing-library/user-event'
it('should submit form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
render(<Form onSubmit={onSubmit} />)
// userEvent methods are async
await user.type(screen.getByLabelText('Email'), 'test@example.com')
await user.click(screen.getByRole('button', { name: /submit/i }))
await waitFor(() => {
expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
})
})
```
## Fake Timers
### When to Use Fake Timers
- Testing components with `setTimeout`/`setInterval`
- Testing debounce/throttle behavior
- Testing animations or delayed transitions
- Testing polling or retry logic
### Basic Fake Timer Setup
```typescript
describe('Debounced Search', () => {
beforeEach(() => {
jest.useFakeTimers()
})
afterEach(() => {
jest.useRealTimers()
})
it('should debounce search input', async () => {
const onSearch = jest.fn()
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
// Type in the input
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })
// Search not called immediately
expect(onSearch).not.toHaveBeenCalled()
// Advance timers
jest.advanceTimersByTime(300)
// Now search is called
expect(onSearch).toHaveBeenCalledWith('query')
})
})
```
### Fake Timers with Async Code
```typescript
it('should retry on failure', async () => {
jest.useFakeTimers()
const fetchData = jest.fn()
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ data: 'success' })
render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)
// First call fails
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
// Advance timer for retry
jest.advanceTimersByTime(1000)
// Second call succeeds
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(2)
expect(screen.getByText('success')).toBeInTheDocument()
})
jest.useRealTimers()
})
```
### Common Fake Timer Utilities
```typescript
// Run all pending timers
jest.runAllTimers()
// Run only pending timers (not new ones created during execution)
jest.runOnlyPendingTimers()
// Advance by specific time
jest.advanceTimersByTime(1000)
// Get current fake time
jest.now()
// Clear all timers
jest.clearAllTimers()
```
## API Testing Patterns
### Loading → Success → Error States
```typescript
describe('DataFetcher', () => {
beforeEach(() => {
jest.clearAllMocks()
})
it('should show loading state', () => {
mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves
render(<DataFetcher />)
expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })
render(<DataFetcher />)
// Use findBy* for multiple async elements (better error messages than waitFor with multiple assertions)
const item1 = await screen.findByText('Item 1')
const item2 = await screen.findByText('Item 2')
expect(item1).toBeInTheDocument()
expect(item2).toBeInTheDocument()
expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
})
})
it('should retry on error', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
fireEvent.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText('Item 1')).toBeInTheDocument()
})
})
})
```
### Testing Mutations
```typescript
it('should submit form and show success', async () => {
const user = userEvent.setup()
mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })
render(<CreateItemForm />)
await user.type(screen.getByLabelText('Name'), 'New Item')
await user.click(screen.getByRole('button', { name: /create/i }))
// Button should be disabled during submission
expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
})
expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
})
```
## useEffect Testing
### Testing Effect Execution
```typescript
it('should fetch data on mount', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
render(<ComponentWithEffect fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
})
```
### Testing Effect Dependencies
```typescript
it('should refetch when id changes', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('1')
})
rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('2')
expect(fetchData).toHaveBeenCalledTimes(2)
})
})
```
### Testing Effect Cleanup
```typescript
it('should cleanup subscription on unmount', () => {
const subscribe = jest.fn()
const unsubscribe = jest.fn()
subscribe.mockReturnValue(unsubscribe)
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
expect(subscribe).toHaveBeenCalledTimes(1)
unmount()
expect(unsubscribe).toHaveBeenCalledTimes(1)
})
```
## Common Async Pitfalls
### ❌ Don't: Forget to await
```typescript
// Bad - test may pass even if assertion fails
it('should load data', () => {
render(<Component />)
waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
// Good - properly awaited
it('should load data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### ❌ Don't: Use multiple assertions in single waitFor
```typescript
// Bad - if first assertion fails, won't know about second
await waitFor(() => {
expect(screen.getByText('Title')).toBeInTheDocument()
expect(screen.getByText('Description')).toBeInTheDocument()
})
// Good - separate waitFor or use findBy
const title = await screen.findByText('Title')
const description = await screen.findByText('Description')
expect(title).toBeInTheDocument()
expect(description).toBeInTheDocument()
```
### ❌ Don't: Mix fake timers with real async
```typescript
// Bad - fake timers don't work well with real Promises
jest.useFakeTimers()
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout!
// Good - use runAllTimers or advanceTimersByTime
jest.useFakeTimers()
render(<Component />)
jest.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument()
```

View File

@@ -0,0 +1,449 @@
# Common Testing Patterns
## Query Priority
Use queries in this order (most to least preferred):
```typescript
// 1. getByRole - Most recommended (accessibility)
screen.getByRole('button', { name: /submit/i })
screen.getByRole('textbox', { name: /email/i })
screen.getByRole('heading', { level: 1 })
// 2. getByLabelText - Form fields
screen.getByLabelText('Email address')
screen.getByLabelText(/password/i)
// 3. getByPlaceholderText - When no label
screen.getByPlaceholderText('Search...')
// 4. getByText - Non-interactive elements
screen.getByText('Welcome to Dify')
screen.getByText(/loading/i)
// 5. getByDisplayValue - Current input value
screen.getByDisplayValue('current value')
// 6. getByAltText - Images
screen.getByAltText('Company logo')
// 7. getByTitle - Tooltip elements
screen.getByTitle('Close')
// 8. getByTestId - Last resort only!
screen.getByTestId('custom-element')
```
## Event Handling Patterns
### Click Events
```typescript
// Basic click
fireEvent.click(screen.getByRole('button'))
// With userEvent (preferred for realistic interaction)
const user = userEvent.setup()
await user.click(screen.getByRole('button'))
// Double click
await user.dblClick(screen.getByRole('button'))
// Right click
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
```
### Form Input
```typescript
const user = userEvent.setup()
// Type in input
await user.type(screen.getByRole('textbox'), 'Hello World')
// Clear and type
await user.clear(screen.getByRole('textbox'))
await user.type(screen.getByRole('textbox'), 'New value')
// Select option
await user.selectOptions(screen.getByRole('combobox'), 'option-value')
// Check checkbox
await user.click(screen.getByRole('checkbox'))
// Upload file
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
```
### Keyboard Events
```typescript
const user = userEvent.setup()
// Press Enter
await user.keyboard('{Enter}')
// Press Escape
await user.keyboard('{Escape}')
// Keyboard shortcut
await user.keyboard('{Control>}a{/Control}') // Ctrl+A
// Tab navigation
await user.tab()
// Arrow keys
await user.keyboard('{ArrowDown}')
await user.keyboard('{ArrowUp}')
```
## Component State Testing
### Testing State Transitions
```typescript
describe('Counter', () => {
it('should increment count', async () => {
const user = userEvent.setup()
render(<Counter initialCount={0} />)
// Initial state
expect(screen.getByText('Count: 0')).toBeInTheDocument()
// Trigger transition
await user.click(screen.getByRole('button', { name: /increment/i }))
// New state
expect(screen.getByText('Count: 1')).toBeInTheDocument()
})
})
```
### Testing Controlled Components
```typescript
describe('ControlledInput', () => {
it('should call onChange with new value', async () => {
const user = userEvent.setup()
const handleChange = jest.fn()
render(<ControlledInput value="" onChange={handleChange} />)
await user.type(screen.getByRole('textbox'), 'a')
expect(handleChange).toHaveBeenCalledWith('a')
})
it('should display controlled value', () => {
render(<ControlledInput value="controlled" onChange={jest.fn()} />)
expect(screen.getByRole('textbox')).toHaveValue('controlled')
})
})
```
## Conditional Rendering Testing
```typescript
describe('ConditionalComponent', () => {
it('should show loading state', () => {
render(<DataDisplay isLoading={true} data={null} />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
})
it('should show error state', () => {
render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)
expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
})
it('should show data when loaded', () => {
render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)
expect(screen.getByText('Test')).toBeInTheDocument()
})
it('should show empty state when no data', () => {
render(<DataDisplay isLoading={false} data={[]} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
})
```
## List Rendering Testing
```typescript
describe('ItemList', () => {
const items = [
{ id: '1', name: 'Item 1' },
{ id: '2', name: 'Item 2' },
{ id: '3', name: 'Item 3' },
]
it('should render all items', () => {
render(<ItemList items={items} />)
expect(screen.getAllByRole('listitem')).toHaveLength(3)
items.forEach(item => {
expect(screen.getByText(item.name)).toBeInTheDocument()
})
})
it('should handle item selection', async () => {
const user = userEvent.setup()
const onSelect = jest.fn()
render(<ItemList items={items} onSelect={onSelect} />)
await user.click(screen.getByText('Item 2'))
expect(onSelect).toHaveBeenCalledWith(items[1])
})
it('should handle empty list', () => {
render(<ItemList items={[]} />)
expect(screen.getByText(/no items/i)).toBeInTheDocument()
})
})
```
## Modal/Dialog Testing
```typescript
describe('Modal', () => {
it('should not render when closed', () => {
render(<Modal isOpen={false} onClose={jest.fn()} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render when open', () => {
render(<Modal isOpen={true} onClose={jest.fn()} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should call onClose when clicking overlay', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.click(screen.getByTestId('modal-overlay'))
expect(handleClose).toHaveBeenCalled()
})
it('should call onClose when pressing Escape', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.keyboard('{Escape}')
expect(handleClose).toHaveBeenCalled()
})
it('should trap focus inside modal', async () => {
const user = userEvent.setup()
render(
<Modal isOpen={true} onClose={jest.fn()}>
<button>First</button>
<button>Second</button>
</Modal>
)
// Focus should cycle within modal
await user.tab()
expect(screen.getByText('First')).toHaveFocus()
await user.tab()
expect(screen.getByText('Second')).toHaveFocus()
await user.tab()
expect(screen.getByText('First')).toHaveFocus() // Cycles back
})
})
```
## Form Testing
```typescript
describe('LoginForm', () => {
it('should submit valid form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(onSubmit).toHaveBeenCalledWith({
email: 'test@example.com',
password: 'password123',
})
})
it('should show validation errors', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
// Submit empty form
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/email is required/i)).toBeInTheDocument()
expect(screen.getByText(/password is required/i)).toBeInTheDocument()
})
it('should validate email format', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
})
it('should disable submit button while submitting', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
})
})
})
```
## Data-Driven Tests with test.each
```typescript
describe('StatusBadge', () => {
test.each([
['success', 'bg-green-500'],
['warning', 'bg-yellow-500'],
['error', 'bg-red-500'],
['info', 'bg-blue-500'],
])('should apply correct class for %s status', (status, expectedClass) => {
render(<StatusBadge status={status} />)
expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
})
test.each([
{ input: null, expected: 'Unknown' },
{ input: undefined, expected: 'Unknown' },
{ input: '', expected: 'Unknown' },
{ input: 'invalid', expected: 'Unknown' },
])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
render(<StatusBadge status={input} />)
expect(screen.getByText(expected)).toBeInTheDocument()
})
})
```
## Debugging Tips
```typescript
// Print entire DOM
screen.debug()
// Print specific element
screen.debug(screen.getByRole('button'))
// Log testing playground URL
screen.logTestingPlaygroundURL()
// Pretty print DOM
import { prettyDOM } from '@testing-library/react'
console.log(prettyDOM(screen.getByRole('dialog')))
// Check available roles
import { getRoles } from '@testing-library/react'
console.log(getRoles(container))
```
## Common Mistakes to Avoid
### ❌ Don't Use Implementation Details
```typescript
// Bad - testing implementation
expect(component.state.isOpen).toBe(true)
expect(wrapper.find('.internal-class').length).toBe(1)
// Good - testing behavior
expect(screen.getByRole('dialog')).toBeInTheDocument()
```
### ❌ Don't Forget Cleanup
```typescript
// Bad - may leak state between tests
it('test 1', () => {
render(<Component />)
})
// Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => {
jest.clearAllMocks()
})
```
### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)
```typescript
// ❌ Bad - hardcoded strings are brittle
expect(screen.getByText('Submit Form')).toBeInTheDocument()
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Good - role-based queries (most semantic)
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Good - pattern matching (flexible)
expect(screen.getByText(/submit/i)).toBeInTheDocument()
expect(screen.getByText(/loading/i)).toBeInTheDocument()
// ✅ Good - test behavior, not exact UI text
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByRole('alert')).toBeInTheDocument()
```
**Why prefer black-box assertions?**
- Text content may change (i18n, copy updates)
- Role-based queries test accessibility
- Pattern matching is resilient to minor changes
- Tests focus on behavior, not implementation details
### ❌ Don't Assert on Absence Without Query
```typescript
// Bad - throws if not found
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!
// Good - use queryBy for absence assertions
expect(screen.queryByText('Error')).not.toBeInTheDocument()
```

View File

@@ -0,0 +1,523 @@
# Domain-Specific Component Testing
This guide covers testing patterns for Dify's domain-specific components.
## Workflow Components (`workflow/`)
Workflow components handle node configuration, data flow, and graph operations.
### Key Test Areas
1. **Node Configuration**
1. **Data Validation**
1. **Variable Passing**
1. **Edge Connections**
1. **Error Handling**
### Example: Node Configuration Panel
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import NodeConfigPanel from './node-config-panel'
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
// Mock workflow context
jest.mock('@/app/components/workflow/hooks', () => ({
useWorkflowStore: () => mockWorkflowStore,
useNodesInteractions: () => mockNodesInteractions,
}))
let mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
}
let mockNodesInteractions = {
handleNodeSelect: jest.fn(),
handleNodeDelete: jest.fn(),
}
describe('NodeConfigPanel', () => {
beforeEach(() => {
jest.clearAllMocks()
mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
}
})
describe('Node Configuration', () => {
it('should render node type selector', () => {
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
})
it('should update node config on change', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')
expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
node.id,
expect.objectContaining({ model: 'gpt-4' })
)
})
})
describe('Data Validation', () => {
it('should show error for invalid input', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'code' })
render(<NodeConfigPanel node={node} />)
// Enter invalid code
const codeInput = screen.getByLabelText(/code/i)
await user.clear(codeInput)
await user.type(codeInput, 'invalid syntax {{{')
await waitFor(() => {
expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
})
})
it('should validate required fields', async () => {
const node = createMockNode({ type: 'http', data: { url: '' } })
render(<NodeConfigPanel node={node} />)
fireEvent.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/url is required/i)).toBeInTheDocument()
})
})
})
describe('Variable Passing', () => {
it('should display available variables from upstream nodes', () => {
const upstreamNode = createMockNode({
id: 'node-1',
type: 'start',
data: { outputs: [{ name: 'user_input', type: 'string' }] },
})
const currentNode = createMockNode({
id: 'node-2',
type: 'llm',
})
mockWorkflowStore.nodes = [upstreamNode, currentNode]
mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]
render(<NodeConfigPanel node={currentNode} />)
// Variable selector should show upstream variables
fireEvent.click(screen.getByRole('button', { name: /add variable/i }))
expect(screen.getByText('user_input')).toBeInTheDocument()
})
it('should insert variable into prompt template', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
// Click variable button
await user.click(screen.getByRole('button', { name: /insert variable/i }))
await user.click(screen.getByText('user_input'))
const promptInput = screen.getByLabelText(/prompt/i)
expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
})
})
})
```
## Dataset Components (`dataset/`)
Dataset components handle file uploads, data display, and search/filter operations.
### Key Test Areas
1. **File Upload**
1. **File Type Validation**
1. **Pagination**
1. **Search & Filtering**
1. **Data Format Handling**
### Example: Document Uploader
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader'
jest.mock('@/service/datasets', () => ({
uploadDocument: jest.fn(),
parseDocument: jest.fn(),
}))
import * as datasetService from '@/service/datasets'
const mockedService = datasetService as jest.Mocked<typeof datasetService>
describe('DocumentUploader', () => {
beforeEach(() => {
jest.clearAllMocks()
})
describe('File Upload', () => {
it('should accept valid file types', async () => {
const user = userEvent.setup()
const onUpload = jest.fn()
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
render(<DocumentUploader onUpload={onUpload} />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
await waitFor(() => {
expect(mockedService.uploadDocument).toHaveBeenCalledWith(
expect.any(FormData)
)
})
})
it('should reject invalid file types', async () => {
const user = userEvent.setup()
render(<DocumentUploader />)
const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
expect(mockedService.uploadDocument).not.toHaveBeenCalled()
})
it('should show upload progress', async () => {
const user = userEvent.setup()
// Mock upload with progress
mockedService.uploadDocument.mockImplementation(() => {
return new Promise((resolve) => {
setTimeout(() => resolve({ id: 'doc-1' }), 100)
})
})
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
expect(screen.getByRole('progressbar')).toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
})
})
})
describe('Error Handling', () => {
it('should handle upload failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
})
})
it('should allow retry after failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ id: 'doc-1' })
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
await user.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
})
})
})
})
```
### Example: Document List with Pagination
```typescript
describe('DocumentList', () => {
describe('Pagination', () => {
it('should load first page on mount', async () => {
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
})
it('should navigate to next page', async () => {
const user = userEvent.setup()
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '11', name: 'Doc 11' }],
total: 50,
page: 2,
pageSize: 10,
})
await user.click(screen.getByRole('button', { name: /next/i }))
await waitFor(() => {
expect(screen.getByText('Doc 11')).toBeInTheDocument()
})
})
})
describe('Search & Filtering', () => {
it('should filter by search query', async () => {
const user = userEvent.setup()
jest.useFakeTimers()
render(<DocumentList datasetId="ds-1" />)
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
// Debounce
jest.advanceTimersByTime(300)
await waitFor(() => {
expect(mockedService.getDocuments).toHaveBeenCalledWith(
'ds-1',
expect.objectContaining({ search: 'test query' })
)
})
jest.useRealTimers()
})
})
})
```
## Configuration Components (`app/configuration/`, `config/`)
Configuration components handle forms, validation, and data persistence.
### Key Test Areas
1. **Form Validation**
1. **Save/Reset**
1. **Required vs Optional Fields**
1. **Configuration Persistence**
1. **Error Feedback**
### Example: App Configuration Form
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form'
jest.mock('@/service/apps', () => ({
updateAppConfig: jest.fn(),
getAppConfig: jest.fn(),
}))
import * as appService from '@/service/apps'
const mockedService = appService as jest.Mocked<typeof appService>
describe('AppConfigForm', () => {
const defaultConfig = {
name: 'My App',
description: '',
icon: 'default',
openingStatement: '',
}
beforeEach(() => {
jest.clearAllMocks()
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
})
describe('Form Validation', () => {
it('should require app name', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Clear name field
await user.clear(screen.getByLabelText(/name/i))
await user.click(screen.getByRole('button', { name: /save/i }))
expect(screen.getByText(/name is required/i)).toBeInTheDocument()
expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
})
it('should validate name length', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
})
// Enter very long name
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))
expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
})
it('should allow empty optional fields', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Leave description empty (optional)
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalled()
})
})
})
describe('Save/Reset Functionality', () => {
it('should save configuration', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Updated App')
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
'app-1',
expect.objectContaining({ name: 'Updated App' })
)
})
expect(screen.getByText(/saved successfully/i)).toBeInTheDocument()
})
it('should reset to default values', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Changed Name')
// Reset
await user.click(screen.getByRole('button', { name: /reset/i }))
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
it('should show unsaved changes warning', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.type(screen.getByLabelText(/name/i), ' Updated')
expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
})
})
describe('Error Handling', () => {
it('should show error on save failure', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
})
})
})
})
```

View File

@@ -0,0 +1,363 @@
# Mocking Guide for Dify Frontend Tests
## ⚠️ Important: What NOT to Mock
### DO NOT Mock Base Components
**Never mock components from `@/app/components/base/`** such as:
- `Loading`, `Spinner`
- `Button`, `Input`, `Select`
- `Tooltip`, `Modal`, `Dropdown`
- `Icon`, `Badge`, `Tag`
**Why?**
- Base components will have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
// ❌ WRONG: Don't mock base components
jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
import Button from '@/app/components/base/button'
// They will render normally in tests
```
### What TO Mock
Only mock these categories:
1. **API services** (`@/service/*`) - Network calls
1. **Complex context providers** - When setup is too difficult
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys
## Mock Placement
| Location | Purpose |
|----------|---------|
| `web/__mocks__/` | Reusable mocks shared across multiple test files |
| Test file | Test-specific mocks, inline with `jest.mock()` |
## Essential Mocks
### 1. i18n (Auto-loaded via Shared Mock)
A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
```typescript
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
}))
```
### 2. Next.js Router
```typescript
const mockPush = jest.fn()
const mockReplace = jest.fn()
jest.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
replace: mockReplace,
back: jest.fn(),
prefetch: jest.fn(),
}),
usePathname: () => '/current-path',
useSearchParams: () => new URLSearchParams('?key=value'),
}))
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
})
it('should navigate on click', () => {
render(<Component />)
fireEvent.click(screen.getByRole('button'))
expect(mockPush).toHaveBeenCalledWith('/expected-path')
})
})
```
### 3. Portal Components (with Shared State)
```typescript
// ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false
jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, ...props }: any) => {
mockPortalOpenState = open || false // Update shared state
return <div data-testid="portal" data-open={open}>{children}</div>
},
PortalToFollowElemContent: ({ children }: any) => {
// ✅ Matches actual: returns null when portal is closed
if (!mockPortalOpenState) return null
return <div data-testid="portal-content">{children}</div>
},
PortalToFollowElemTrigger: ({ children }: any) => (
<div data-testid="portal-trigger">{children}</div>
),
}))
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
mockPortalOpenState = false // ✅ Reset shared state
})
})
```
### 4. API Service Mocks
```typescript
import * as api from '@/service/api'
jest.mock('@/service/api')
const mockedApi = api as jest.Mocked<typeof api>
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
// Setup default mock implementation
mockedApi.fetchData.mockResolvedValue({ data: [] })
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
render(<Component />)
await waitFor(() => {
expect(screen.getByText('1')).toBeInTheDocument()
})
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<Component />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 5. HTTP Mocking with Nock
```typescript
import nock from 'nock'
const GITHUB_HOST = 'https://api.github.com'
const GITHUB_PATH = '/repos/owner/repo'
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
return nock(GITHUB_HOST)
.get(GITHUB_PATH)
.delay(delayMs)
.reply(status, body)
}
describe('GithubComponent', () => {
afterEach(() => {
nock.cleanAll()
})
it('should display repo info', async () => {
mockGithubApi(200, { name: 'dify', stars: 1000 })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText('dify')).toBeInTheDocument()
})
})
it('should handle API error', async () => {
mockGithubApi(500, { message: 'Server error' })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 6. Context Providers
```typescript
import { ProviderContext } from '@/context/provider-context'
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
describe('Component with Context', () => {
it('should render for free plan', () => {
const mockContext = createMockPlan('sandbox')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.getByText('Upgrade')).toBeInTheDocument()
})
it('should render for pro plan', () => {
const mockContext = createMockPlan('professional')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
})
})
```
### 7. SWR / React Query
```typescript
// SWR
jest.mock('swr', () => ({
__esModule: true,
default: jest.fn(),
}))
import useSWR from 'swr'
const mockedUseSWR = useSWR as jest.Mock
describe('Component with SWR', () => {
it('should show loading state', () => {
mockedUseSWR.mockReturnValue({
data: undefined,
error: undefined,
isLoading: true,
})
render(<Component />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
})
// React Query
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false },
mutations: { retry: false },
},
})
const renderWithQueryClient = (ui: React.ReactElement) => {
const queryClient = createTestQueryClient()
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>
)
}
```
## Mock Best Practices
### ✅ DO
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
1. **Import actual types** for type safety
1. **Reset shared mock state** in `beforeEach`
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
1. Don't use `any` types in mocks without necessity
### Mock Decision Tree
```
Need to use a component in test?
├─ Is it from @/app/components/base/*?
│ └─ YES → Import real component, DO NOT mock
├─ Is it a project component?
│ └─ YES → Prefer importing real component
│ Only mock if setup is extremely complex
├─ Is it an API service (@/service/*)?
│ └─ YES → Mock it
├─ Is it a third-party lib with side effects?
│ └─ YES → Mock it (next/navigation, external SDKs)
└─ Is it i18n?
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
## Factory Function Pattern
```typescript
// __mocks__/data-factories.ts
import type { User, Project } from '@/types'
export const createMockUser = (overrides: Partial<User> = {}): User => ({
id: 'user-1',
name: 'Test User',
email: 'test@example.com',
role: 'member',
createdAt: new Date().toISOString(),
...overrides,
})
export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
id: 'project-1',
name: 'Test Project',
description: 'A test project',
owner: createMockUser(),
members: [],
createdAt: new Date().toISOString(),
...overrides,
})
// Usage in tests
it('should display project owner', () => {
const project = createMockProject({
owner: createMockUser({ name: 'John Doe' }),
})
render(<ProjectCard project={project} />)
expect(screen.getByText('John Doe')).toBeInTheDocument()
})
```

View File

@@ -0,0 +1,269 @@
# Testing Workflow Guide
This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
| Scope | Rule |
|-------|------|
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
| **Multi-file directory** | Process one file at a time, verify each before proceeding |
## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
### Why Incremental?
| Batch Approach (❌) | Incremental Approach (✅) |
|---------------------|---------------------------|
| Generate 5+ tests at once | Generate 1 test at a time |
| Run tests only at the end | Run test immediately after each file |
| Multiple failures compound | Single point of failure, easy to debug |
| Hard to identify root cause | Clear cause-effect relationship |
| Mock issues affect many files | Mock issues caught early |
| Messy git history | Clean, atomic commits possible |
## Single File Workflow
When testing a **single component, hook, or utility**:
```
1. Read source code completely
2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected
4. Write the test file
5. Run test: `pnpm test -- <file>.spec.tsx`
6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch)
```
## Directory/Multi-File Workflow (MUST FOLLOW)
When testing a **directory or multiple files**, follow this strict workflow:
### Step 1: Analyze and Plan
1. **List all files** that need tests in the directory
1. **Categorize by complexity**:
- 🟢 **Simple**: Utility functions, simple hooks, presentational components
- 🟡 **Medium**: Components with state, effects, or event handlers
- 🔴 **Complex**: Components with API calls, routing, or many dependencies
1. **Order by dependency**: Test dependencies before dependents
1. **Create a todo list** to track progress
### Step 2: Determine Processing Order
Process files in this recommended order:
```
1. Utility functions (simplest, no React)
2. Custom hooks (isolated logic)
3. Simple presentational components (few/no props)
4. Medium complexity components (state, effects)
5. Complex components (API, routing, many deps)
6. Container/index components (integration tests - last)
```
**Rationale**:
- Simpler files help establish mock patterns
- Hooks used by components should be tested first
- Integration tests (index files) depend on child components working
### Step 3: Process Each File Incrementally
**For EACH file in the ordered list:**
```
┌─────────────────────────────────────────────┐
│ 1. Write test file │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 3. If FAIL → Fix immediately, re-run │
│ 4. If PASS → Mark complete in todo list │
│ 5. ONLY THEN proceed to next file │
└─────────────────────────────────────────────┘
```
**DO NOT proceed to the next file until the current one passes.**
### Step 4: Final Verification
After all individual tests pass:
```bash
# Run all tests in the directory together
pnpm test -- path/to/directory/
# Check coverage
pnpm test -- --coverage path/to/directory/
```
## Component Complexity Guidelines
Use `pnpm analyze-component <path>` to assess complexity before testing.
### 🔴 Very Complex Components (Complexity > 50)
**Consider refactoring BEFORE testing:**
- Break component into smaller, testable pieces
- Extract complex logic into custom hooks
- Separate container and presentational layers
**If testing as-is:**
- Use integration tests for complex workflows
- Use `test.each()` for data-driven testing
- Multiple `describe` blocks for organization
- Consider testing major sections separately
### 🟡 Medium Complexity (Complexity 30-50)
- Group related tests in `describe` blocks
- Test integration scenarios between internal parts
- Focus on state transitions and side effects
- Use helper functions to reduce test complexity
### 🟢 Simple Components (Complexity < 30)
- Standard test structure
- Focus on props, rendering, and edge cases
- Usually straightforward to test
### 📏 Large Files (500+ lines)
Regardless of complexity score:
- **Strongly consider refactoring** before testing
- If testing as-is, test major sections separately
- Create helper functions for test setup
- May need multiple test files
## Todo List Format
When testing multiple files, use a todo list like this:
```
Testing: path/to/directory/
Ordered by complexity (simple → complex):
☐ utils/helper.ts [utility, simple]
☐ hooks/use-custom-hook.ts [hook, simple]
☐ empty-state.tsx [component, simple]
☐ item-card.tsx [component, medium]
☐ list.tsx [component, complex]
☐ index.tsx [integration]
Progress: 0/6 complete
```
Update status as you complete each:
- ☐ → ⏳ (in progress)
- ⏳ → ✅ (complete and verified)
- ⏳ → ❌ (blocked, needs attention)
## When to Stop and Verify
**Always run tests after:**
- Completing a test file
- Making changes to fix a failure
- Modifying shared mocks
- Updating test utilities or helpers
**Signs you should pause:**
- More than 2 consecutive test failures
- Mock-related errors appearing
- Unclear why a test is failing
- Test passing but coverage unexpectedly low
## Common Pitfalls to Avoid
### ❌ Don't: Generate Everything First
```
# BAD: Writing all files then testing
Write component-a.spec.tsx
Write component-b.spec.tsx
Write component-c.spec.tsx
Write component-d.spec.tsx
Run pnpm test ← Multiple failures, hard to debug
```
### ✅ Do: Verify Each Step
```
# GOOD: Incremental with verification
Write component-a.spec.tsx
Run pnpm test -- component-a.spec.tsx ✅
Write component-b.spec.tsx
Run pnpm test -- component-b.spec.tsx ✅
...continue...
```
### ❌ Don't: Skip Verification for "Simple" Components
Even simple components can have:
- Import errors
- Missing mock setup
- Incorrect assumptions about props
**Always verify, regardless of perceived simplicity.**
### ❌ Don't: Continue When Tests Fail
Failing tests compound:
- A mock issue in file A affects files B, C, D
- Fixing A later requires revisiting all dependent tests
- Time wasted on debugging cascading failures
**Fix failures immediately before proceeding.**
## Integration with Claude's Todo Feature
When using Claude for multi-file testing:
1. **Ask Claude to create a todo list** before starting
1. **Request one file at a time** or ensure Claude processes incrementally
1. **Verify each test passes** before asking for the next
1. **Mark todos complete** as you progress
Example prompt:
```
Test all components in `path/to/directory/`.
First, analyze the directory and create a todo list ordered by complexity.
Then, process ONE file at a time, waiting for my confirmation that tests pass
before proceeding to the next.
```
## Summary Checklist
Before starting multi-file testing:
- [ ] Listed all files needing tests
- [ ] Ordered by complexity (simple → complex)
- [ ] Created todo list for tracking
- [ ] Understand dependencies between files
During testing:
- [ ] Processing ONE file at a time
- [ ] Running tests after EACH file
- [ ] Fixing failures BEFORE proceeding
- [ ] Updating todo list progress
After completion:
- [ ] All individual tests pass
- [ ] Full directory test run passes
- [ ] Coverage goals met
- [ ] Todo list shows all complete

View File

@@ -0,0 +1,296 @@
/**
* Test Template for React Components
*
* WHY THIS STRUCTURE?
* - Organized sections make tests easy to navigate and maintain
* - Mocks at top ensure consistent test isolation
* - Factory functions reduce duplication and improve readability
* - describe blocks group related scenarios for better debugging
*
* INSTRUCTIONS:
* 1. Replace `ComponentName` with your component name
* 2. Update import path
* 3. Add/remove test sections based on component features (use analyze-component)
* 4. Follow AAA pattern: Arrange → Act → Assert
*
* RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
*/
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
// import ComponentName from './index'
// ============================================================================
// Mocks
// ============================================================================
// WHY: Mocks must be hoisted to top of file (Jest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (automatically mocked)
// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// jest.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = jest.fn()
// jest.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))
// API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
// jest.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between
// parent and child mocks to correctly simulate open/close behavior
// let mockOpenState = false
// ============================================================================
// Test Data Factories
// ============================================================================
// WHY FACTORIES?
// - Avoid hard-coded test data scattered across tests
// - Easy to create variations with overrides
// - Type-safe when using actual types from source
// - Single source of truth for default test values
// const createMockProps = (overrides = {}) => ({
// // Default props that make component render successfully
// ...overrides,
// })
// const createMockItem = (overrides = {}) => ({
// id: 'item-1',
// name: 'Test Item',
// ...overrides,
// })
// ============================================================================
// Test Helpers
// ============================================================================
// const renderComponent = (props = {}) => {
// return render(<ComponentName {...createMockProps(props)} />)
// }
// ============================================================================
// Tests
// ============================================================================
describe('ComponentName', () => {
// WHY beforeEach with clearAllMocks?
// - Ensures each test starts with clean slate
// - Prevents mock call history from leaking between tests
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
beforeEach(() => {
jest.clearAllMocks()
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
// mockOpenState = false
})
// --------------------------------------------------------------------------
// Rendering Tests (REQUIRED - Every component MUST have these)
// --------------------------------------------------------------------------
// WHY: Catches import errors, missing providers, and basic render issues
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange - Setup data and mocks
// const props = createMockProps()
// Act - Render the component
// render(<ComponentName {...props} />)
// Assert - Verify expected output
// Prefer getByRole for accessibility; it's what users "see"
// expect(screen.getByRole('...')).toBeInTheDocument()
})
it('should render with default props', () => {
// WHY: Verifies component works without optional props
// render(<ComponentName />)
// expect(screen.getByText('...')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Props Tests (REQUIRED - Every component MUST test prop behavior)
// --------------------------------------------------------------------------
// WHY: Props are the component's API contract. Test them thoroughly.
describe('Props', () => {
it('should apply custom className', () => {
// WHY: Common pattern in Dify - components should merge custom classes
// render(<ComponentName className="custom-class" />)
// expect(screen.getByTestId('component')).toHaveClass('custom-class')
})
it('should use default values for optional props', () => {
// WHY: Verifies TypeScript defaults work at runtime
// render(<ComponentName />)
// expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
})
})
// --------------------------------------------------------------------------
// User Interactions (if component has event handlers - on*, handle*)
// --------------------------------------------------------------------------
// WHY: Event handlers are core functionality. Test from user's perspective.
describe('User Interactions', () => {
it('should call onClick when clicked', async () => {
// WHY userEvent over fireEvent?
// - userEvent simulates real user behavior (focus, hover, then click)
// - fireEvent is lower-level, doesn't trigger all browser events
// const user = userEvent.setup()
// const handleClick = jest.fn()
// render(<ComponentName onClick={handleClick} />)
//
// await user.click(screen.getByRole('button'))
//
// expect(handleClick).toHaveBeenCalledTimes(1)
})
it('should call onChange when value changes', async () => {
// const user = userEvent.setup()
// const handleChange = jest.fn()
// render(<ComponentName onChange={handleChange} />)
//
// await user.type(screen.getByRole('textbox'), 'new value')
//
// expect(handleChange).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// State Management (if component uses useState/useReducer)
// --------------------------------------------------------------------------
// WHY: Test state through observable UI changes, not internal state values
describe('State Management', () => {
it('should update state on interaction', async () => {
// WHY test via UI, not state?
// - State is implementation detail; UI is what users see
// - If UI works correctly, state must be correct
// const user = userEvent.setup()
// render(<ComponentName />)
//
// // Initial state - verify what user sees
// expect(screen.getByText('Initial')).toBeInTheDocument()
//
// // Trigger state change via user action
// await user.click(screen.getByRole('button'))
//
// // New state - verify UI updated
// expect(screen.getByText('Updated')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Async Operations (if component fetches data - useSWR, useQuery, fetch)
// --------------------------------------------------------------------------
// WHY: Async operations have 3 states users experience: loading, success, error
describe('Async Operations', () => {
it('should show loading state', () => {
// WHY never-resolving promise?
// - Keeps component in loading state for assertion
// - Alternative: use fake timers
// mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
// render(<ComponentName />)
//
// expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
it('should show data on success', async () => {
// WHY waitFor?
// - Component updates asynchronously after fetch resolves
// - waitFor retries assertion until it passes or times out
// mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText('Item 1')).toBeInTheDocument()
// })
})
it('should show error on failure', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText(/error/i)).toBeInTheDocument()
// })
})
})
// --------------------------------------------------------------------------
// Edge Cases (REQUIRED - Every component MUST handle edge cases)
// --------------------------------------------------------------------------
// WHY: Real-world data is messy. Components must handle:
// - Null/undefined from API failures or optional fields
// - Empty arrays/strings from user clearing data
// - Boundary values (0, MAX_INT, special characters)
describe('Edge Cases', () => {
it('should handle null value', () => {
// WHY test null specifically?
// - API might return null for missing data
// - Prevents "Cannot read property of null" in production
// render(<ComponentName value={null} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle undefined value', () => {
// WHY test undefined separately from null?
// - TypeScript treats them differently
// - Optional props are undefined, not null
// render(<ComponentName value={undefined} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
// WHY: Empty state often needs special UI (e.g., "No items yet")
// render(<ComponentName items={[]} />)
// expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
it('should handle empty string', () => {
// WHY: Empty strings are truthy in JS but visually empty
// render(<ComponentName text="" />)
// expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Accessibility (optional but recommended for Dify's enterprise users)
// --------------------------------------------------------------------------
// WHY: Dify has enterprise customers who may require accessibility compliance
describe('Accessibility', () => {
it('should have accessible name', () => {
// WHY getByRole with name?
// - Tests that screen readers can identify the element
// - Enforces proper labeling practices
// render(<ComponentName label="Test Label" />)
// expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
})
it('should support keyboard navigation', async () => {
// WHY: Some users can't use a mouse
// const user = userEvent.setup()
// render(<ComponentName />)
//
// await user.tab()
// expect(screen.getByRole('button')).toHaveFocus()
})
})
})

View File

@@ -0,0 +1,207 @@
/**
* Test Template for Custom Hooks
*
* Instructions:
* 1. Replace `useHookName` with your hook name
* 2. Update import path
* 3. Add/remove test sections based on hook features
*/
import { renderHook, act, waitFor } from '@testing-library/react'
// import { useHookName } from './use-hook-name'
// ============================================================================
// Mocks
// ============================================================================
// API services (if hook fetches data)
// jest.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// ============================================================================
// Test Helpers
// ============================================================================
// Wrapper for hooks that need context
// const createWrapper = (contextValue = {}) => {
// return ({ children }: { children: React.ReactNode }) => (
// <SomeContext.Provider value={contextValue}>
// {children}
// </SomeContext.Provider>
// )
// }
// ============================================================================
// Tests
// ============================================================================
describe('useHookName', () => {
beforeEach(() => {
jest.clearAllMocks()
})
// --------------------------------------------------------------------------
// Initial State
// --------------------------------------------------------------------------
describe('Initial State', () => {
it('should return initial state', () => {
// const { result } = renderHook(() => useHookName())
//
// expect(result.current.value).toBe(initialValue)
// expect(result.current.isLoading).toBe(false)
})
it('should accept initial value from props', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
//
// expect(result.current.value).toBe('custom')
})
})
// --------------------------------------------------------------------------
// State Updates
// --------------------------------------------------------------------------
describe('State Updates', () => {
it('should update value when setValue is called', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(result.current.value).toBe('new value')
})
it('should reset to initial value', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
//
// act(() => {
// result.current.setValue('changed')
// })
// expect(result.current.value).toBe('changed')
//
// act(() => {
// result.current.reset()
// })
// expect(result.current.value).toBe('initial')
})
})
// --------------------------------------------------------------------------
// Async Operations
// --------------------------------------------------------------------------
describe('Async Operations', () => {
it('should fetch data on mount', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result } = renderHook(() => useHookName())
//
// // Initially loading
// expect(result.current.isLoading).toBe(true)
//
// // Wait for data
// await waitFor(() => {
// expect(result.current.isLoading).toBe(false)
// })
//
// expect(result.current.data).toEqual({ data: 'test' })
})
it('should handle fetch error', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
//
// const { result } = renderHook(() => useHookName())
//
// await waitFor(() => {
// expect(result.current.error).toBeTruthy()
// })
//
// expect(result.current.error?.message).toBe('Network error')
})
it('should refetch when dependency changes', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result, rerender } = renderHook(
// ({ id }) => useHookName(id),
// { initialProps: { id: '1' } }
// )
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
// })
//
// rerender({ id: '2' })
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
// })
})
})
// --------------------------------------------------------------------------
// Side Effects
// --------------------------------------------------------------------------
describe('Side Effects', () => {
it('should call callback when value changes', () => {
// const callback = jest.fn()
// const { result } = renderHook(() => useHookName({ onChange: callback }))
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(callback).toHaveBeenCalledWith('new value')
})
it('should cleanup on unmount', () => {
// const cleanup = jest.fn()
// jest.spyOn(window, 'addEventListener')
// jest.spyOn(window, 'removeEventListener')
//
// const { unmount } = renderHook(() => useHookName())
//
// expect(window.addEventListener).toHaveBeenCalled()
//
// unmount()
//
// expect(window.removeEventListener).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle null input', () => {
// const { result } = renderHook(() => useHookName(null))
//
// expect(result.current.value).toBeNull()
})
it('should handle rapid updates', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('1')
// result.current.setValue('2')
// result.current.setValue('3')
// })
//
// expect(result.current.value).toBe('3')
})
})
// --------------------------------------------------------------------------
// With Context (if hook uses context)
// --------------------------------------------------------------------------
describe('With Context', () => {
it('should use context value', () => {
// const wrapper = createWrapper({ someValue: 'context-value' })
// const { result } = renderHook(() => useHookName(), { wrapper })
//
// expect(result.current.contextValue).toBe('context-value')
})
})
})

View File

@@ -0,0 +1,154 @@
/**
* Test Template for Utility Functions
*
* Instructions:
* 1. Replace `utilityFunction` with your function name
* 2. Update import path
* 3. Use test.each for data-driven tests
*/
// import { utilityFunction } from './utility'
// ============================================================================
// Tests
// ============================================================================
describe('utilityFunction', () => {
// --------------------------------------------------------------------------
// Basic Functionality
// --------------------------------------------------------------------------
describe('Basic Functionality', () => {
it('should return expected result for valid input', () => {
// expect(utilityFunction('input')).toBe('expected-output')
})
it('should handle multiple arguments', () => {
// expect(utilityFunction('a', 'b', 'c')).toBe('abc')
})
})
// --------------------------------------------------------------------------
// Data-Driven Tests
// --------------------------------------------------------------------------
describe('Input/Output Mapping', () => {
test.each([
// [input, expected]
['input1', 'output1'],
['input2', 'output2'],
['input3', 'output3'],
])('should return %s for input %s', (input, expected) => {
// expect(utilityFunction(input)).toBe(expected)
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle empty string', () => {
// expect(utilityFunction('')).toBe('')
})
it('should handle null', () => {
// expect(utilityFunction(null)).toBe(null)
// or
// expect(() => utilityFunction(null)).toThrow()
})
it('should handle undefined', () => {
// expect(utilityFunction(undefined)).toBe(undefined)
// or
// expect(() => utilityFunction(undefined)).toThrow()
})
it('should handle empty array', () => {
// expect(utilityFunction([])).toEqual([])
})
it('should handle empty object', () => {
// expect(utilityFunction({})).toEqual({})
})
})
// --------------------------------------------------------------------------
// Boundary Conditions
// --------------------------------------------------------------------------
describe('Boundary Conditions', () => {
it('should handle minimum value', () => {
// expect(utilityFunction(0)).toBe(0)
})
it('should handle maximum value', () => {
// expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
})
it('should handle negative numbers', () => {
// expect(utilityFunction(-1)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Type Coercion (if applicable)
// --------------------------------------------------------------------------
describe('Type Handling', () => {
it('should handle numeric string', () => {
// expect(utilityFunction('123')).toBe(123)
})
it('should handle boolean', () => {
// expect(utilityFunction(true)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Error Cases
// --------------------------------------------------------------------------
describe('Error Handling', () => {
it('should throw for invalid input', () => {
// expect(() => utilityFunction('invalid')).toThrow('Error message')
})
it('should throw with specific error type', () => {
// expect(() => utilityFunction('invalid')).toThrow(ValidationError)
})
})
// --------------------------------------------------------------------------
// Complex Objects (if applicable)
// --------------------------------------------------------------------------
describe('Object Handling', () => {
it('should preserve object structure', () => {
// const input = { a: 1, b: 2 }
// expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
})
it('should handle nested objects', () => {
// const input = { nested: { deep: 'value' } }
// expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
})
it('should not mutate input', () => {
// const input = { a: 1 }
// const inputCopy = { ...input }
// utilityFunction(input)
// expect(input).toEqual(inputCopy)
})
})
// --------------------------------------------------------------------------
// Array Handling (if applicable)
// --------------------------------------------------------------------------
describe('Array Handling', () => {
it('should process all elements', () => {
// expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
})
it('should handle single element array', () => {
// expect(utilityFunction([1])).toEqual([2])
})
it('should preserve order', () => {
// expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
})
})
})

5
.coveragerc Normal file
View File

@@ -0,0 +1,5 @@
[run]
omit =
api/tests/*
api/migrations/*
api/core/rag/datasource/vdb/*

View File

@@ -1,4 +1,4 @@
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev

View File

@@ -11,7 +11,7 @@
"nodeGypDependencies": true,
"version": "lts"
},
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
"ghcr.io/devcontainers-extra/features/npm-package:1": {
"package": "typescript",
"version": "latest"
},

View File

@@ -6,11 +6,10 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
source /home/vscode/.bashrc

View File

@@ -29,7 +29,7 @@ trim_trailing_whitespace = false
# Matches multiple files with brace expansion notation
# Set default charset
[*.{js,tsx}]
[*.{js,jsx,ts,tsx,mjs}]
indent_style = space
indent_size = 2

240
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1,240 @@
# CODEOWNERS
# This file defines code ownership for the Dify project.
# Each line is a file pattern followed by one or more owners.
# Owners can be @username, @org/team-name, or email addresses.
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
* @crazywoola @laipz8200 @Yeuoly
# CODEOWNERS file
.github/CODEOWNERS @laipz8200 @crazywoola
# Docs
docs/ @crazywoola
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
# Backend - MCP
api/core/mcp/ @Nov1c444
api/core/entities/mcp_provider.py @Nov1c444
api/services/tools/mcp_tools_manage_service.py @Nov1c444
api/controllers/mcp/ @Nov1c444
api/controllers/console/app/mcp_server.py @Nov1c444
api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
api/core/workflow/graph/ @laipz8200 @QuantumGhost
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
api/core/workflow/nodes/agent/ @Nov1c444
api/core/workflow/nodes/iteration/ @Nov1c444
api/core/workflow/nodes/loop/ @Nov1c444
api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
api/core/rag/ @JohnJyong
api/services/rag_pipeline/ @JohnJyong
api/services/dataset_service.py @JohnJyong
api/services/knowledge_service.py @JohnJyong
api/services/external_knowledge_service.py @JohnJyong
api/services/hit_testing_service.py @JohnJyong
api/services/metadata_service.py @JohnJyong
api/services/vector_service.py @JohnJyong
api/services/entities/knowledge_entities/ @JohnJyong
api/services/entities/external_knowledge_entities/ @JohnJyong
api/controllers/console/datasets/ @JohnJyong
api/controllers/service_api/dataset/ @JohnJyong
api/models/dataset.py @JohnJyong
api/tasks/rag_pipeline/ @JohnJyong
api/tasks/add_document_to_index_task.py @JohnJyong
api/tasks/batch_clean_document_task.py @JohnJyong
api/tasks/clean_document_task.py @JohnJyong
api/tasks/clean_notion_document_task.py @JohnJyong
api/tasks/document_indexing_task.py @JohnJyong
api/tasks/document_indexing_sync_task.py @JohnJyong
api/tasks/document_indexing_update_task.py @JohnJyong
api/tasks/duplicate_document_indexing_task.py @JohnJyong
api/tasks/recover_document_indexing_task.py @JohnJyong
api/tasks/remove_document_from_index_task.py @JohnJyong
api/tasks/retry_document_indexing_task.py @JohnJyong
api/tasks/sync_website_document_indexing_task.py @JohnJyong
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
api/tasks/create_segment_to_index_task.py @JohnJyong
api/tasks/delete_segment_from_index_task.py @JohnJyong
api/tasks/disable_segment_from_index_task.py @JohnJyong
api/tasks/disable_segments_from_index_task.py @JohnJyong
api/tasks/enable_segment_to_index_task.py @JohnJyong
api/tasks/enable_segments_to_index_task.py @JohnJyong
api/tasks/clean_dataset_task.py @JohnJyong
api/tasks/deal_dataset_index_update_task.py @JohnJyong
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins
api/core/plugin/ @Mairuis @Yeuoly @Stream29
api/services/plugin/ @Mairuis @Yeuoly @Stream29
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook
api/controllers/trigger/ @Mairuis @Yeuoly
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
api/core/trigger/ @Mairuis @Yeuoly
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
api/services/trigger/ @Mairuis @Yeuoly
api/models/trigger.py @Mairuis @Yeuoly
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/libs/schedule_utils.py @Mairuis @Yeuoly
api/services/workflow/scheduler.py @Mairuis @Yeuoly
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow
api/services/async_workflow_service.py @Mairuis @Yeuoly
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing
api/services/billing_service.py @hj24 @zyssyz123
api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise
api/configs/enterprise/ @GarfieldDai @GareArc
api/services/enterprise/ @GarfieldDai @GareArc
api/services/feature_service.py @GarfieldDai @GareArc
api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200 @MRZHUH
# Frontend
web/ @iamjoel
# Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh
web/app/components/app/configuration/ @iamjoel @zxhlyh
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat
web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion
web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation
web/app/components/apps/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation
web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
web/app/components/app/log/ @JzoNgKVO @iamjoel
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation
web/app/components/datasets/list/ @iamjoel @WTW0313
web/app/components/datasets/create/ @iamjoel @WTW0313
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
web/app/components/rag-pipeline/ @iamjoel @WTW0313
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings
web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins
web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools
web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - MarketPlace
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration
web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel
web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication
web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control
web/app/(shareLayout)/components/ @douxc @iamjoel
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page
web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics
web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components
web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks
web/utils/classnames.ts @iamjoel @zxhlyh
web/utils/time.ts @iamjoel @zxhlyh
web/utils/format.ts @iamjoel @zxhlyh
web/utils/clipboard.ts @iamjoel @zxhlyh
web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education
web/app/components/billing/ @iamjoel @zxhlyh
web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh

View File

@@ -1,5 +1,8 @@
blank_issues_enabled: false
contact_links:
- name: "\U0001F510 Security Vulnerabilities"
url: "https://github.com/langgenius/dify/security/advisories/new"
about: Report security vulnerabilities through GitHub Security Advisories to ensure responsible disclosure. 💡 Please do not report security vulnerabilities in public issues.
- name: "\U0001F4A1 Model Providers & Plugins"
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.

View File

@@ -1,8 +1,6 @@
name: "✨ Refactor"
description: Refactor existing code for improved readability and maintainability.
title: "[Chore/Refactor] "
labels:
- refactor
name: "✨ Refactor or Chore"
description: Refactor existing code or perform maintenance chores to improve readability and reliability.
title: "[Refactor/Chore] "
body:
- type: checkboxes
attributes:
@@ -11,7 +9,7 @@ body:
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
- label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
@@ -25,14 +23,14 @@ body:
id: description
attributes:
label: Description
placeholder: "Describe the refactor you are proposing."
placeholder: "Describe the refactor or chore you are proposing."
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
placeholder: "Explain why this refactor is necessary."
placeholder: "Explain why this refactor or chore is necessary."
validations:
required: false
- type: textarea

View File

@@ -1,13 +0,0 @@
name: "👾 Tracker"
description: For inner usages, please do not use this template.
title: "[Tracker] "
labels:
- tracker
body:
- type: textarea
id: content
attributes:
label: Blockers
placeholder: "- [ ] ..."
validations:
required: true

View File

@@ -39,25 +39,11 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run pyrefly check
run: |
cd api
uv add --dev pyrefly
uv run pyrefly check || true
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
@@ -76,7 +62,7 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
db_postgres
redis
sandbox
ssrf_proxy
@@ -85,11 +71,34 @@ jobs:
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh
- name: Run API Tests
env:
STORAGE_TYPE: opendal
OPENDAL_SCHEME: fs
OPENDAL_FS_ROOT: /tmp/dify-storage
run: |
uv run --project api pytest \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \
api/tests/test_containers_integration_tests \
api/tests/unit_tests
- name: Run Tool
run: uv run --project api bash dev/pytest/pytest_tools.sh
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY

View File

@@ -2,6 +2,8 @@ name: autofix.ci
on:
pull_request:
branches: ["main"]
push:
branches: ["main"]
permissions:
contents: read
@@ -11,23 +13,34 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
- uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: "3.11"
- uses: astral-sh/setup-uv@v6
- run: |
cd api
uv sync --dev
# fmt first to avoid line too long
uv run ruff format ..
# Fix lint errors
uv run ruff check --fix .
# Format code
uv run ruff format ..
- name: count migration progress
run: |
cd api
./cnt_base.sh
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
# ast-grep exits 1 if no matches are found; allow idempotent runs.
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union
@@ -45,14 +58,15 @@ jobs:
pattern: $T
fix: $T | None
EOF
uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx mdformat .
uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
- name: Install pnpm
uses: pnpm/action-setup@v4
@@ -65,7 +79,7 @@ jobs:
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
working-directory: ./web
@@ -73,7 +87,6 @@ jobs:
- name: oxlint
working-directory: ./web
run: |
pnpx oxlint --fix
run: pnpm exec oxlint --config .oxlintrc.json --fix .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@@ -4,8 +4,7 @@ on:
push:
branches:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "deploy/**"
- "build/**"
- "release/e-*"
- "hotfix/**"

View File

@@ -8,7 +8,7 @@ concurrency:
cancel-in-progress: true
jobs:
db-migration-test:
db-migration-test-postgres:
runs-on: ubuntu-latest
steps:
@@ -45,7 +45,7 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
db_postgres
redis
- name: Prepare configs
@@ -57,3 +57,60 @@ jobs:
env:
DEBUG: true
run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"
cache-dependency-glob: api/uv.lock
- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env for MySQL
run: |
cd docker
cp middleware.env.example middleware.env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db_mysql
redis
- name: Prepare configs for MySQL
run: |
cd api
cp .env.example .env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
- name: Run DB Migration
env:
DEBUG: true
run: uv run --directory api flask upgrade-db

View File

@@ -18,7 +18,7 @@ jobs:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@@ -1,4 +1,4 @@
name: Deploy RAG Dev
name: Deploy Trigger Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
- "deploy/trigger-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@@ -1,6 +1,7 @@
#!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
@@ -13,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

View File

@@ -0,0 +1,21 @@
name: Semantic Pull Request
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
jobs:
lint:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -90,7 +90,7 @@ jobs:
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
@@ -103,6 +103,11 @@ jobs:
run: |
pnpm run lint
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run type-check:tsgo
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest

View File

@@ -20,22 +20,22 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Check for file changes in i18n/en-US
id: check_files
run: |
recent_commit_sha=$(git rev-parse HEAD)
second_recent_commit_sha=$(git rev-parse HEAD~1)
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .ts)
file_args="$file_args --file=$filename"
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
@@ -55,7 +55,7 @@ jobs:
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
@@ -77,12 +77,15 @@ jobs:
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update i18n files and type definitions based on en-US changes
title: 'chore: translate i18n files and update type definitions'
commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files and update type definitions'
body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
**Changes included:**
- Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates
branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true

View File

@@ -51,13 +51,13 @@ jobs:
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Vector Store (TiDB)
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: docker/tidb/docker-compose.yaml
services: |
tidb
tiflash
# - name: Set up Vector Store (TiDB)
# uses: hoverkraft-tech/compose-action@v2.0.2
# with:
# compose-file: docker/tidb/docker-compose.yaml
# services: |
# tidb
# tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
uses: hoverkraft-tech/compose-action@v2.0.2
@@ -83,8 +83,8 @@ jobs:
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Check VDB Ready (TiDB)
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@@ -13,6 +13,7 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
@@ -21,14 +22,7 @@ jobs:
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: web/**
- name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
@@ -36,23 +30,347 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm test
run: |
pnpm exec jest \
--ci \
--runInBand \
--coverage \
--passWithNoTests
- name: Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Jest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v4
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error

12
.gitignore vendored
View File

@@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so
# *db files
*.db
# Distribution / packaging
.Python
build/
@@ -97,6 +100,7 @@ __pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat-schedule.db
celerybeat.pid
# SageMath parsed files
@@ -182,7 +186,10 @@ docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d
docker/volumes/iris/*
docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
@@ -234,4 +241,7 @@ scripts/stress-test/reports/
# mcp
.playwright-mcp/
.serena/
.serena/
# settings
*.local.json

1
.nvmrc Normal file
View File

@@ -0,0 +1 @@
22.11.0

View File

@@ -8,8 +8,7 @@
"module": "flask",
"env": {
"FLASK_APP": "app.py",
"FLASK_ENV": "development",
"GEVENT_SUPPORT": "True"
"FLASK_ENV": "development"
},
"args": [
"run",
@@ -28,9 +27,7 @@
"type": "debugpy",
"request": "launch",
"module": "celery",
"env": {
"GEVENT_SUPPORT": "True"
},
"env": {},
"args": [
"-A",
"app.celery",
@@ -40,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,generation,mail,ops_trace",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],

View File

@@ -4,84 +4,51 @@
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
The codebase consists of:
The codebase is split into:
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
- **Docker deployment** (`/docker`): Containerized deployment configurations
## Development Commands
## Backend Workflow
### Backend (API)
- Run backend CLI commands through `uv run --project api <command>`.
All Python commands must be prefixed with `uv run --project api`:
- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
```bash
# Start development servers
./dev/start-api # Start API server
./dev/start-worker # Start Celery worker
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
# Run tests
uv run --project api pytest # Run all tests
uv run --project api pytest tests/unit_tests/ # Unit tests only
uv run --project api pytest tests/integration_tests/ # Integration tests
- Integration tests are CI-only and are not expected to run in the local environment.
# Code quality
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --directory api basedpyright # Type checking
```
### Frontend (Web)
## Frontend Workflow
```bash
cd web
pnpm lint # Run ESLint
pnpm eslint-fix # Fix ESLint issues
pnpm test # Run Jest tests
pnpm lint:fix
pnpm type-check:tsgo
pnpm test
```
## Testing Guidelines
## Testing & Quality Practices
### Backend Testing
- Follow TDD: red → green → refactor.
- Use `pytest` for backend tests with Arrange-Act-Assert structure.
- Enforce strong typing; avoid `Any` and prefer explicit type annotations.
- Write self-documenting code; only add comments that explain intent.
- Use `pytest` for all backend tests
- Write tests first (TDD approach)
- Test structure: Arrange-Act-Assert
## Language Style
## Code Style Requirements
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
### Python
## General Practices
- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) appropriately
- Prefer editing existing files; add new documentation only when requested.
- Inject dependencies through constructors and preserve clean architecture boundaries.
- Handle errors with domain-specific exceptions at the correct layer.
### TypeScript/JavaScript
## Project Conventions
- Strict TypeScript configuration
- ESLint with Prettier integration
- Avoid `any` type
## Important Notes
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
- **Comments**: Only write meaningful comments that explain "why", not "what"
- **File Creation**: Always prefer editing existing files over creating new ones
- **Documentation**: Don't create documentation files unless explicitly requested
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
## Common Development Tasks
### Adding a New API Endpoint
1. Create controller in `/api/controllers/`
1. Add service logic in `/api/services/`
1. Update routes in controller's `__init__.py`
1. Write tests in `/api/tests/`
## Project-Specific Conventions
- All async tasks use Celery with Redis as broker
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
- Backend architecture adheres to DDD and Clean Architecture principles.
- Async work runs through Celery with Redis as the broker.
- Frontend user-facing strings must use `web/i18n/en-US/`; avoid hardcoded text.

View File

@@ -77,6 +77,8 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
#### Backend
For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly.

View File

@@ -26,7 +26,6 @@ prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
@@ -71,6 +70,11 @@ type-check:
@uv run --directory api --dev basedpyright
@echo "✅ Type check complete"
test:
@echo "🧪 Running backend unit tests..."
@uv run --project api --dev dev/pytest/pytest_unit_tests.sh
@echo "✅ Tests complete"
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@@ -120,6 +124,7 @@ help:
@echo " make check - Check code with ruff"
@echo " make lint - Format and fix code with ruff"
@echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@@ -129,4 +134,4 @@ help:
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test

View File

@@ -36,22 +36,28 @@
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
</p>
<p align="center">
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
</p>
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
@@ -63,7 +69,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
> - CPU >= 2 Core
> - RAM >= 4 GiB
</br>
<br/>
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
@@ -109,15 +115,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
## Using Dify
- **Cloud </br>**
- **Cloud <br/>**
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
- **Self-hosting Dify Community Edition</br>**
- **Self-hosting Dify Community Edition<br/>**
Quickly get Dify running in your environment with this [starter guide](#quick-start).
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for enterprise / organizations</br>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
- **Dify for enterprise / organizations<br/>**
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
@@ -129,8 +135,31 @@ Star Dify on GitHub and be instantly notified of new releases.
## Advanced Setup
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
#### Customizing Suggested Questions
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
```bash
# In your .env file
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
SUGGESTED_QUESTIONS_MAX_TOKENS=512
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
```
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)

View File

@@ -27,6 +27,9 @@ FILES_URL=http://localhost:5001
# Example: INTERNAL_FILES_URL=http://api:5001
INTERNAL_FILES_URL=http://127.0.0.1:5001
# TRIGGER URL
TRIGGER_URL=http://localhost:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
@@ -69,12 +72,15 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
# PostgreSQL database configuration
# Database configuration
DB_TYPE=postgresql
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=localhost
DB_PORT=5432
DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true
SQLALCHEMY_POOL_TIMEOUT=30
@@ -156,9 +162,11 @@ SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the sites top-level domain (e.g., `example.com`). Leading dots are optional.
COOKIE_DOMAIN=
# Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -168,6 +176,18 @@ WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
OCEANBASE_FULLTEXT_PARSER=ik
SEEKDB_MEMORY_LIMIT=2G
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333
@@ -334,14 +354,14 @@ LINDORM_PASSWORD=admin
LINDORM_USING_UGC=True
LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306
ALIBABACLOUD_MYSQL_USER=root
ALIBABACLOUD_MYSQL_PASSWORD=root
ALIBABACLOUD_MYSQL_DATABASE=dify
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
ALIBABACLOUD_MYSQL_HNSW_M=6
# openGauss configuration
OPENGAUSS_HOST=127.0.0.1
@@ -359,6 +379,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Comma-separated list of file extensions blocked from upload for security reasons.
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
# Empty by default to allow all file types.
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
UPLOAD_FILE_EXTENSION_BLACKLIST=
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
@@ -425,10 +451,13 @@ CODE_EXECUTION_SSL_VERIFY=True
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_EXECUTION_CONNECT_TIMEOUT=10
CODE_EXECUTION_READ_TIMEOUT=60
CODE_EXECUTION_WRITE_TIMEOUT=10
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
CODE_MAX_STRING_LENGTH=400000
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
@@ -445,6 +474,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True
# Webhook request configuration
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
@@ -500,7 +532,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# Workflow log cleanup configuration
# Enable automatic cleanup of workflow run logs to manage database size
WORKFLOW_LOG_CLEANUP_ENABLED=true
WORKFLOW_LOG_CLEANUP_ENABLED=false
# Number of days to retain workflow run logs (default: 30 days)
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
@@ -508,8 +540,28 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Aliyun SLS Logstore Configuration
# Aliyun Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365 days 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@@ -522,6 +574,12 @@ ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
# Interval time in minutes for polling scheduled workflows(default: 1 min)
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
# Position configuration
POSITION_TOOL_PINS=
@@ -593,3 +651,47 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Suggested Questions After Answer Configuration
# These environment variables allow customization of the suggested questions feature
#
# Custom prompt for generating suggested questions (optional)
# If not set, uses the default prompt that generates 3 questions under 20 characters each
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
# SUGGESTED_QUESTIONS_PROMPT=
# Maximum number of tokens for suggested questions generation (default: 256)
# Adjust this value for longer questions or more questions
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
# Temperature for suggested questions generation (default: 0.0)
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
# SUGGESTED_QUESTIONS_TEMPERATURE=0
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0
# Multimodal knowledgebase limit
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
# Maximum allowed CSV file size for annotation import in megabytes
ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
#Maximum number of annotation records allowed in a single import
ANNOTATION_IMPORT_MAX_RECORDS=10000
# Minimum number of annotation records required in a single import
ANNOTATION_IMPORT_MIN_RECORDS=1
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@@ -16,6 +16,7 @@ layers =
graph
nodes
node_events
runtime
entities
containers =
core.workflow

View File

@@ -36,17 +36,20 @@ select = [
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum,
"S110", # disallow the try-except-pass pattern.
# security related linting rules
# RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec`
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum
"S311", # suspicious-non-cryptographic-random-usage,
]
ignore = [
@@ -81,7 +84,6 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
]
[lint.per-file-ignores]
@@ -92,18 +94,16 @@ ignore = [
"configs/*" = [
"N802", # invalid-function-name
]
"core/model_runtime/callbacks/base_callback.py" = [
"T201",
]
"core/workflow/callbacks/workflow_logging_callback.py" = [
"T201",
]
"core/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
]
[lint.pyflakes]

View File

@@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion"
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
]
}
]

62
api/AGENTS.md Normal file
View File

@@ -0,0 +1,62 @@
# Agent Skill Index
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
______________________________________________________________________
## Platform Foundations
- **[Infrastructure Overview](agent_skills/infra.md)**\
When to read this:
- You need to understand where a feature belongs in the architecture.
- Youre wiring storage, Redis, vector stores, or OTEL.
- Youre about to add CLI commands or async jobs.\
What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
- **[Coding Style](agent_skills/coding_style.md)**\
When to read this:
- Youre writing or reviewing backend code and need the authoritative checklist.
- Youre unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
- You want the exact lint/type/test commands used in PRs.\
Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
______________________________________________________________________
## Plugin & Extension Development
- **[Plugin Systems](agent_skills/plugin.md)**\
When to read this:
- Youre building or debugging a marketplace plugin.
- You need to know how manifests, providers, daemons, and migrations fit together.\
What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
When to read this:
- You must integrate OAuth for a plugin or datasource.
- Youre handling credential encryption or refresh flows.\
Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
______________________________________________________________________
## Workflow Entry & Execution
- **[Trigger Concepts](agent_skills/trigger.md)**\
When to read this:
- Youre debugging why a workflow didnt start.
- Youre adding a new trigger type or hook.
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
______________________________________________________________________
## Additional Notes for Agents
- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.

View File

@@ -15,7 +15,11 @@ FROM base AS packages
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
&& apt-get install -y --no-install-recommends \
# basic environment
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies
COPY pyproject.toml uv.lock ./
@@ -44,14 +48,22 @@ ENV PYTHONIOENCODING=utf-8
WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \
apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
curl nodejs \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
expat libldap-2.5-0=2.5.13+dfsg-5 perl libsqlite3-0=3.40.1-2+deb12u2 zlib1g=1:1.2.13.dfsg-1 \
# install fonts to support the use of tools like pypdfium2
fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension
@@ -63,24 +75,29 @@ RUN \
# Copy Python environment and packages
ENV VIRTUAL_ENV=/app/api/.venv
COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
# Copy source code
COPY . /app/api/
COPY --chown=dify:dify . /app/api/
# Prepare entrypoint script
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
# Copy entrypoint
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
ENV NLTK_DATA=/usr/local/share/nltk_data
USER dify
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

View File

@@ -15,8 +15,8 @@
```bash
cd ../docker
cp middleware.env.example middleware.env
# change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
cd ../api
```
@@ -26,6 +26,10 @@
cp .env.example .env
```
> [!IMPORTANT]
>
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the sites top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
1. Generate a `SECRET_KEY` in the `.env` file.
bash for Linux
@@ -80,10 +84,10 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
```bash
uv run celery -A app.celery beat

View File

@@ -0,0 +1,115 @@
## Linter
- Always follow `.ruff.toml`.
- Run `uv run ruff check --fix --unsafe-fixes`.
- Keep each line under 100 characters (including spaces).
## Code Style
- `snake_case` for variables and functions.
- `PascalCase` for classes.
- `UPPER_CASE` for constants.
## Rules
- Use Pydantic v2 standard.
- Use `uv` for package management.
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
- Prefer simple functions over classes for lightweight helpers.
- Keep files below 800 lines; split when necessary.
- Keep code readable—no clever hacks.
- Never use `print`; log with `logger = logging.getLogger(__name__)`.
## Guiding Principles
- Mirror the projects layered architecture: controller → service → core/domain.
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
## SQLAlchemy Patterns
- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
- Open sessions with context managers:
```python
from sqlalchemy.orm import Session
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
workflow = session.execute(stmt).scalar_one_or_none()
```
- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
## Storage & External IO
- Access storage via `extensions.ext_storage.storage`.
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
## Pydantic Usage
- Define DTOs with Pydantic v2 models and forbid extras by default.
- Use `@field_validator` / `@model_validator` for domain rules.
- Example:
```python
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
class TriggerConfig(BaseModel):
endpoint: HttpUrl
secret: str
model_config = ConfigDict(extra="forbid")
@field_validator("secret")
def ensure_secret_prefix(cls, value: str) -> str:
if not value.startswith("dify_"):
raise ValueError("secret must start with dify_")
return value
```
## Generics & Protocols
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
## Error Handling & Logging
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
- Declare `logger = logging.getLogger(__name__)` at module top.
- Include tenant/app/workflow identifiers in log context.
- Log retryable events at `warning`, terminal failures at `error`.
## Tooling & Checks
- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
- Type checks: `uv run --directory api --dev basedpyright`.
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Run all of the above before submitting your work.
## Controllers & Services
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
- Document non-obvious behaviour with concise comments.
## Miscellaneous
- Use `configs.dify_config` for configuration—never read environment variables directly.
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
- Keep experimental scripts under `dev/`; do not ship them in production builds.

96
api/agent_skills/infra.md Normal file
View File

@@ -0,0 +1,96 @@
## Configuration
- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
## Dependencies
- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
## Storage & Files
- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
## Redis & Shared State
- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
## Models
- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
## Vector Stores
- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
## Observability & OTEL
- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
## Ops Integrations
- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
## Controllers, Services, Core
- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
## Plugins, Tools, Providers
- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
## Async Workloads
see `agent_skills/trigger.md` for more detailed documentation.
- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
## Database & Migrations
- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
## CLI Commands
- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
## When You Add Features
- Check for an existing helper or service before writing a new util.
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.

View File

@@ -0,0 +1 @@
// TBD

View File

@@ -0,0 +1 @@
// TBD

View File

@@ -0,0 +1,53 @@
## Overview
Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
## Trigger nodes
- `UserInput`
- `Trigger Webhook`
- `Trigger Schedule`
- `Trigger Plugin`
### UserInput
Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
1. For its detailed implementation, please refer to `core/workflow/nodes/start`
### Trigger Webhook
Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
### Trigger Schedule
`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
### Trigger Plugin
`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
## Worker Pool / Async Task
All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
## Debug Strategy
Dify divided users into 2 groups: builders / end users.
Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.

View File

@@ -1,7 +1,7 @@
import sys
def is_db_command():
def is_db_command() -> bool:
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True
return False
@@ -13,23 +13,12 @@ if is_db_command():
app = create_migrations_app()
else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
# Gunicorn and Celery handle monkey patching automatically in production by
# specifying the `gevent` worker class. Manual monkey patching is not required here.
#
# # gevent
# monkey.patch_all()
# See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
from app_factory import create_app

View File

@@ -1,6 +1,8 @@
import logging
import time
from opentelemetry.trace import get_current_span
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
@@ -18,6 +20,7 @@ def create_flask_app_with_configs() -> DifyApp:
"""
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())
dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True
# add before request hook
@dify_app.before_request
@@ -25,8 +28,25 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
@dify_app.after_request
def add_trace_id_header(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header
return dify_app
@@ -50,10 +70,12 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
ext_logging,
ext_login,
ext_logstore,
ext_mail,
ext_migrate,
ext_orjson,
@@ -62,6 +84,7 @@ def initialize_extensions(app: DifyApp):
ext_redis,
ext_request_logging,
ext_sentry,
ext_session_factory,
ext_set_secretkey,
ext_storage,
ext_timezone,
@@ -74,6 +97,7 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
ext_import_modules,
ext_orjson,
ext_forward_refs,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,
@@ -82,6 +106,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_logstore, # Initialize logstore after storage, before celery
ext_celery,
ext_login,
ext_mail,
@@ -92,6 +117,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_otel,
ext_request_logging,
ext_session_factory,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

7
api/cnt_base.sh Executable file
View File

@@ -0,0 +1,7 @@
#!/bin/bash
set -euxo pipefail
for pattern in "Base" "TypeBase"; do
printf "%s " "$pattern"
grep "($pattern):" -r --include='*.py' --exclude-dir=".venv" --exclude-dir="tests" . | wc -l
done

View File

@@ -15,12 +15,12 @@ from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
@@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
if not datasets.items:
break
except SQLAlchemyError:
raise
@@ -1137,6 +1139,7 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
return
all_files_on_storage = []
for storage_path in storage_paths:
@@ -1227,6 +1230,55 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
"""
Setup system trigger oauth client
"""
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerOAuthSystemClient
provider_id = TriggerProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = TriggerOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
"""
Find draft variables that reference non-existent apps.
@@ -1420,7 +1472,10 @@ def setup_datasource_oauth_client(provider, client_params):
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
@click.option(
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
)
def transform_datasource_credentials(environment: str):
"""
Transform datasource credentials
"""
@@ -1431,9 +1486,14 @@ def transform_datasource_credentials():
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
if environment == "online":
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
else:
notion_plugin_unique_identifier = None
firecrawl_plugin_unique_identifier = None
jina_plugin_unique_identifier = None
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
@@ -1521,6 +1581,14 @@ def transform_datasource_credentials():
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
if not firecrawl_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
@@ -1576,6 +1644,14 @@ def transform_datasource_credentials():
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
if not jina_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
@@ -1583,7 +1659,7 @@ def transform_datasource_credentials():
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
provider="jinareader",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,

View File

@@ -73,14 +73,14 @@ class AppExecutionConfig(BaseSettings):
description="Maximum allowed execution time for the application in seconds",
default=1200,
)
APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Default number of concurrent active requests per app (0 for unlimited)",
default=0,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0,
)
APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
description="Maximum number of requests per app per day",
default=5000,
)
class CodeExecutionSandboxConfig(BaseSettings):
@@ -150,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution",
default=80000,
default=400_000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
@@ -174,6 +174,33 @@ class CodeExecutionSandboxConfig(BaseSettings):
)
class TriggerConfig(BaseSettings):
"""
Configuration for trigger
"""
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
description="Maximum allowed size for webhook request bodies in bytes",
default=10485760,
)
class AsyncWorkflowConfig(BaseSettings):
"""
Configuration for async workflow
"""
ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
description="Granularity for async workflow scheduler, "
"sometime, few users could block the queue due to some time-consuming tasks, "
"to avoid this, workflow can be suspended if needed, to achieve"
"this, a time-based checker is required, every granularity seconds, "
"the checker will check the workflow queue and suspend the workflow",
default=120,
ge=1,
)
class PluginConfig(BaseSettings):
"""
Plugin configs
@@ -189,6 +216,11 @@ class PluginConfig(BaseSettings):
default="plugin-api-key",
)
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=600.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
@@ -258,6 +290,8 @@ class EndpointConfig(BaseSettings):
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
)
TRIGGER_URL: str = Field(description="Template url for triggers", default="http://localhost:5001")
class FileAccessConfig(BaseSettings):
"""
@@ -326,12 +360,93 @@ class FileUploadConfig(BaseSettings):
default=10,
)
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a image batch upload operation",
default=10,
)
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a single chunk attachment",
default=10,
)
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for attachments in megabytes",
default=2,
)
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
description="Timeout for downloading image attachments in seconds",
default=60,
)
# Annotation Import Security Configurations
ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed CSV file size for annotation import in megabytes",
default=2,
)
ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field(
description="Maximum number of annotation records allowed in a single import",
default=10000,
)
ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field(
description="Minimum number of annotation records required in a single import",
default=1,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of annotation import requests per minute per tenant",
default=5,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
description="Maximum number of annotation import requests per hour per tenant",
default=20,
)
ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
description="Maximum number of concurrent annotation import tasks per tenant",
default=2,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
"Empty by default to allow all file types."
),
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
default="",
)
@computed_field # type: ignore[misc]
@property
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
"""
Parse and return the blacklist as a set of lowercase extensions.
Returns an empty set if no blacklist is configured.
"""
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
return set()
return {
ext.strip().lower().strip(".")
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
if ext.strip()
}
class HttpConfig(BaseSettings):
"""
HTTP-related configurations for the application
"""
COOKIE_DOMAIN: str = Field(
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
default="",
)
API_COMPRESSION_ENABLED: bool = Field(
description="Enable or disable gzip compression for HTTP responses",
default=False,
@@ -362,11 +477,11 @@ class HttpConfig(BaseSettings):
)
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
)
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
)
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
@@ -489,7 +604,10 @@ class LoggingConfig(BaseSettings):
LOG_FORMAT: str = Field(
description="Format string for log messages",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
default=(
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
),
)
LOG_DATEFORMAT: str | None = Field(
@@ -543,7 +661,7 @@ class UpdateConfig(BaseSettings):
class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
# 100KB
# 1000 KiB
1024_000,
description="Maximum size for variable to trigger final truncation.",
)
@@ -582,6 +700,11 @@ class WorkflowConfig(BaseSettings):
default=200 * 1024,
)
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
description="Maximum number of characters allowed in Template Transform node output",
default=400_000,
)
# GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance",
@@ -766,7 +889,7 @@ class MailConfig(BaseSettings):
MAIL_TEMPLATING_TIMEOUT: int = Field(
description="""
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Only available in sandbox mode.""",
default=3,
)
@@ -905,6 +1028,11 @@ class DataSetConfig(BaseSettings):
default=True,
)
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
description="Maximum number of segments for dataset segments API (0 for unlimited)",
default=0,
)
class WorkspaceConfig(BaseSettings):
"""
@@ -980,6 +1108,44 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable check upgradable plugin task",
default=True,
)
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
description="Enable workflow schedule poller task",
default=True,
)
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
description="Workflow schedule poller interval in minutes",
default=1,
)
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
description="Maximum number of schedules to process in each poll batch",
default=100,
)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
default=0,
)
# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",
default=True,
)
TRIGGER_PROVIDER_REFRESH_INTERVAL: int = Field(
description="Trigger provider refresh poller interval in minutes",
default=1,
)
TRIGGER_PROVIDER_REFRESH_BATCH_SIZE: int = Field(
description="Max trigger subscriptions to process per tick",
default=200,
)
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
description="Proactive credential refresh threshold in seconds",
default=60 * 60,
)
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
description="Proactive subscription refresh threshold in seconds",
default=60 * 60,
)
class PositionConfig(BaseSettings):
@@ -1078,7 +1244,7 @@ class AccountConfig(BaseSettings):
class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
@@ -1097,12 +1263,36 @@ class SwaggerUIConfig(BaseSettings):
)
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
default=1,
)
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig,
@@ -1120,7 +1310,9 @@ class FeatureConfig(
PositionConfig,
RagEtlConfig,
RepositoryConfig,
SandboxExpiredRecordsCleanConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,

View File

@@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
@@ -25,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig
@@ -104,6 +106,12 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
default="localhost",
@@ -139,12 +147,12 @@ class DatabaseConfig(BaseSettings):
default="",
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description="Database URI scheme for SQLAlchemy connection.",
default="postgresql",
)
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"
@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = (
@@ -197,21 +205,21 @@ class DatabaseConfig(BaseSettings):
default=os.cpu_count() or 1,
)
@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
connect_args = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
timezone_opt = "-c timezone=UTC"
if options:
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
@@ -329,7 +337,9 @@ class MiddlewareConfig(
ChromaConfig,
ClickzettaConfig,
HuaweiCloudConfig,
IrisVectorConfig,
MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,

View File

@@ -0,0 +1,54 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AlibabaCloudMySQLConfig(BaseSettings):
"""
Configuration settings for AlibabaCloud MySQL vector database
"""
ALIBABACLOUD_MYSQL_HOST: str = Field(
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
default="localhost",
)
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
default=3306,
)
ALIBABACLOUD_MYSQL_USER: str = Field(
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
default="root",
)
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
default="",
)
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
default="dify",
)
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the connection pool",
default=5,
)
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
default="utf8mb4",
)
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
"(e.g., 'cosine', 'euclidean')",
default="cosine",
)
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
default=6,
)

View File

@@ -0,0 +1,91 @@
"""Configuration for InterSystems IRIS vector database."""
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
class IrisVectorConfig(BaseSettings):
"""Configuration settings for IRIS vector database connection and pooling."""
IRIS_HOST: str | None = Field(
description="Hostname or IP address of the IRIS server.",
default="localhost",
)
IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
description="Port number for IRIS connection.",
default=1972,
)
IRIS_USER: str | None = Field(
description="Username for IRIS authentication.",
default="_SYSTEM",
)
IRIS_PASSWORD: str | None = Field(
description="Password for IRIS authentication.",
default="Dify@1234",
)
IRIS_SCHEMA: str | None = Field(
description="Schema name for IRIS tables.",
default="dify",
)
IRIS_DATABASE: str | None = Field(
description="Database namespace for IRIS connection.",
default="USER",
)
IRIS_CONNECTION_URL: str | None = Field(
description="Full connection URL for IRIS (overrides individual fields if provided).",
default=None,
)
IRIS_MIN_CONNECTION: PositiveInt = Field(
description="Minimum number of connections in the pool.",
default=1,
)
IRIS_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the pool.",
default=3,
)
IRIS_TEXT_INDEX: bool = Field(
description="Enable full-text search index using %iFind.Index.Basic.",
default=True,
)
IRIS_TEXT_INDEX_LANGUAGE: str = Field(
description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
default="en",
)
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""Validate IRIS configuration values.
Args:
values: Configuration dictionary
Returns:
Validated configuration dictionary
Raises:
ValueError: If required fields are missing or pool settings are invalid
"""
# Only validate required fields if IRIS is being used as the vector store
# This allows the config to be loaded even when IRIS is not in use
# vector_store = os.environ.get("VECTOR_STORE", "")
# We rely on Pydantic defaults for required fields if they are missing from env.
# Strict existence check is removed to allow defaults to work.
min_conn = values.get("IRIS_MIN_CONNECTION", 1)
max_conn = values.get("IRIS_MAX_CONNECTION", 3)
if min_conn > max_conn:
raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
return values

View File

@@ -1,23 +1,24 @@
from enum import Enum
from enum import StrEnum
from typing import Literal
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AuthMethod(StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
class OpenSearchConfig(BaseSettings):
"""
Configuration settings for OpenSearch
"""
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: str | None = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,

View File

@@ -22,7 +22,17 @@ class WeaviateConfig(BaseSettings):
default=True,
)
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,
)
WEAVIATE_TOKENIZATION: str | None = Field(
description="Tokenization for Weaviate (default is word)",
default="word",
)

View File

@@ -1,4 +1,5 @@
from configs import dify_config
from libs.collection_utils import convert_to_lower_and_upper_set
HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]"
@@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
_doc_extensions: list[str]
_doc_extensions: set[str]
if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
_doc_extensions = {
"txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"vtt",
"properties",
"doc",
"docx",
"csv",
"eml",
"msg",
"pptx",
"xml",
"epub",
}
if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.append("ppt")
_doc_extensions.add("ppt")
else:
_doc_extensions = [
_doc_extensions = {
"txt",
"markdown",
"md",
@@ -37,5 +53,18 @@ else:
"csv",
"vtt",
"properties",
]
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
# webapp
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
COOKIE_NAME_PASSPORT = "passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@@ -20,6 +20,7 @@ language_timezone_mapping = {
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
"ar-TN": "Africa/Tunis",
}
languages = list(language_timezone_mapping.keys())
@@ -31,3 +32,9 @@ def supported_language(lang):
error = f"{lang} is not a valid language."
raise ValueError(error)
def get_valid_language(lang: str | None) -> str:
if lang and lang in languages:
return lang
return languages[0]

File diff suppressed because one or more lines are too long

View File

@@ -9,6 +9,7 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
"""
@@ -41,3 +42,11 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock")
)
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
ContextVar("plugin_trigger_providers")
)
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_trigger_providers_lock")
)

View File

@@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
class BlockedFileExtensionError(BaseHTTPException):
error_code = "file_extension_blocked"
description = "The file extension is blocked for security reasons."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."

View File

@@ -24,7 +24,7 @@ except ImportError:
)
else:
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
magic = None # type: ignore
magic = None # type: ignore[assignment]
from pydantic import BaseModel

View File

@@ -0,0 +1,26 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
from flask_restx import Namespace
from pydantic import BaseModel
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a single BaseModel with a namespace for Swagger documentation."""
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
"""Register multiple BaseModels with a namespace."""
for model in models:
register_schema_model(namespace, model)
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"register_schema_model",
"register_schema_models",
]

View File

@@ -1,31 +1,10 @@
from importlib import import_module
from flask import Blueprint
from flask_restx import Namespace
from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import (
ConversationApi,
ConversationListApi,
ConversationPinApi,
ConversationRenameApi,
ConversationUnPinApi,
)
from .explore.message import (
MessageFeedbackApi,
MessageListApi,
MessageMoreLikeThisApi,
MessageSuggestedQuestionApi,
)
from .explore.workflow import (
InstalledAppWorkflowRunApi,
InstalledAppWorkflowTaskStopApi,
)
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("console", __name__, url_prefix="/console/api")
api = ExternalApi(
@@ -35,23 +14,23 @@ api = ExternalApi(
description="Console management APIs for app configuration, monitoring, and administration",
)
# Create namespace
console_ns = Namespace("console", description="Console management API operations", path="/")
# File
api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type")
RESOURCE_MODULES = (
"controllers.console.app.app_import",
"controllers.console.explore.audio",
"controllers.console.explore.completion",
"controllers.console.explore.conversation",
"controllers.console.explore.message",
"controllers.console.explore.workflow",
"controllers.console.files",
"controllers.console.remote_files",
)
# Remote files
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
for module_name in RESOURCE_MODULES:
import_module(module_name)
# Ensure resource modules are imported so route decorators are evaluated.
# Import other controllers
from . import (
admin,
@@ -87,6 +66,7 @@ from .app import (
workflow_draft_variable,
workflow_run,
workflow_statistic,
workflow_trigger,
)
# Import auth controllers
@@ -147,80 +127,10 @@ from .workspace import (
models,
plugin,
tool_providers,
trigger_providers,
workspace,
)
# Explore Audio
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
# Explore Completion
api.add_resource(
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
)
api.add_resource(
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)
# Explore Conversation
api.add_resource(
ConversationRenameApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)
# Explore Message
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
api.add_resource(
MessageFeedbackApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)
# Explore Workflow
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)
api.add_namespace(console_ns)
__all__ = [
@@ -288,6 +198,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trigger_providers",
"version",
"website",
"workflow",
@@ -295,5 +206,6 @@ __all__ = [
"workflow_draft_variable",
"workflow_run",
"workflow_statistic",
"workflow_trigger",
"workspace",
]

View File

@@ -3,19 +3,46 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InsertExploreAppPayload(BaseModel):
app_id: str = Field(...)
desc: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]):
@@ -24,19 +51,9 @@ def admin_required(view: Callable[P, R]):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
auth_header = request.headers.get("Authorization")
if auth_header is None:
auth_token = extract_access_token(request)
if not auth_token:
raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
@@ -47,59 +64,36 @@ def admin_required(view: Callable[P, R]):
@console_ns.route("/admin/insert-explore-apps")
class InsertExploreAppListApi(Resource):
@api.doc("insert_explore_app")
@api.doc(description="Insert or update an app in the explore list")
@api.expect(
api.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@api.response(200, "App updated successfully")
@api.response(201, "App inserted successfully")
@api.response(404, "App not found")
@console_ns.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list")
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
@console_ns.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully")
@console_ns.response(404, "App not found")
@only_edition_cloud
@admin_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("desc", type=str, location="json")
parser.add_argument("copyright", type=str, location="json")
parser.add_argument("privacy_policy", type=str, location="json")
parser.add_argument("custom_disclaimer", type=str, location="json")
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
raise NotFound(f"App '{payload.app_id}' is not found")
site = app.site
if not site:
desc = args["desc"] or ""
copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] or ""
desc = payload.desc or ""
copy_right = payload.copyright or ""
privacy_policy = payload.privacy_policy or ""
custom_disclaimer = payload.custom_disclaimer or ""
else:
desc = site.description or args["desc"] or ""
copy_right = site.copyright or args["copyright"] or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
desc = site.description or payload.desc or ""
copy_right = site.copyright or payload.copyright or ""
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
if not recommended_app:
@@ -109,9 +103,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=args["language"],
category=args["category"],
position=args["position"],
language=payload.language,
category=payload.category,
position=payload.position,
)
db.session.add(recommended_app)
@@ -125,9 +119,9 @@ class InsertExploreAppListApi(Resource):
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
recommended_app.language = payload.language
recommended_app.category = payload.category
recommended_app.position = payload.position
app.is_public = True
@@ -138,14 +132,14 @@ class InsertExploreAppListApi(Resource):
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
class InsertExploreAppApi(Resource):
@api.doc("delete_explore_app")
@api.doc(description="Remove an app from the explore list")
@api.doc(params={"app_id": "Application ID to remove"})
@api.response(204, "App removed successfully")
@console_ns.doc("delete_explore_app")
@console_ns.doc(description="Remove an app from the explore list")
@console_ns.doc(params={"app_id": "Application ID to remove"})
@console_ns.response(204, "App removed successfully")
@only_edition_cloud
@admin_required
def delete(self, app_id):
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
@@ -153,13 +147,13 @@ class InsertExploreAppApi(Resource):
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
with session_factory.create_session() as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with Session(db.engine) as session:
with session_factory.create_session() as session:
installed_apps = (
session.execute(
select(InstalledApp).where(

View File

@@ -1,5 +1,4 @@
import flask_restx
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
@@ -8,12 +7,12 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from . import api, console_ns
from .wraps import account_initialization_required, setup_required
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
@@ -25,6 +24,12 @@ api_key_fields = {
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
def _get_resource(resource_id, tenant_id, resource_model):
if resource_model == App:
@@ -53,11 +58,13 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
@marshal_with(api_key_list)
@marshal_with(api_key_list_model)
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
@@ -65,14 +72,13 @@ class BaseApiKeyListResource(Resource):
).all()
return {"items": keys}
@marshal_with(api_key_fields)
@marshal_with(api_key_item_model)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
@@ -89,7 +95,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@@ -104,13 +110,11 @@ class BaseApiKeyResource(Resource):
resource_model: type | None = None
resource_id_field: str | None = None
def delete(self, resource_id, api_key_id):
def delete(self, resource_id: str, api_key_id: str):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -135,28 +139,23 @@ class BaseApiKeyResource(Resource):
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
class AppApiKeyListResource(BaseApiKeyListResource):
@api.doc("get_app_api_keys")
@api.doc(description="Get all API keys for an app")
@api.doc(params={"resource_id": "App ID"})
@api.response(200, "Success", api_key_list)
def get(self, resource_id):
@console_ns.doc("get_app_api_keys")
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "Success", api_key_list_model)
def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return super().get(resource_id)
@api.doc("create_app_api_key")
@api.doc(description="Create a new API key for an app")
@api.doc(params={"resource_id": "App ID"})
@api.response(201, "API key created successfully", api_key_fields)
@api.response(400, "Maximum keys exceeded")
def post(self, resource_id):
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@@ -165,19 +164,14 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class AppApiKeyResource(BaseApiKeyResource):
@api.doc("delete_app_api_key")
@api.doc(description="Delete an API key for an app")
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully")
@console_ns.doc("delete_app_api_key")
@console_ns.doc(description="Delete an API key for an app")
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@@ -185,28 +179,23 @@ class AppApiKeyResource(BaseApiKeyResource):
@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
class DatasetApiKeyListResource(BaseApiKeyListResource):
@api.doc("get_dataset_api_keys")
@api.doc(description="Get all API keys for a dataset")
@api.doc(params={"resource_id": "Dataset ID"})
@api.response(200, "Success", api_key_list)
def get(self, resource_id):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "Success", api_key_list_model)
def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return super().get(resource_id)
@api.doc("create_dataset_api_key")
@api.doc(description="Create a new API key for a dataset")
@api.doc(params={"resource_id": "Dataset ID"})
@api.response(201, "API key created successfully", api_key_fields)
@api.response(400, "Maximum keys exceeded")
def post(self, resource_id):
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"
@@ -215,19 +204,14 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class DatasetApiKeyResource(BaseApiKeyResource):
@api.doc("delete_dataset_api_key")
@api.doc(description="Delete an API key for a dataset")
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully")
@console_ns.doc("delete_dataset_api_key")
@console_ns.doc(description="Delete an API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"

View File

@@ -1,35 +1,39 @@
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
app_mode: str = Field(..., description="Application mode")
model_mode: str = Field(..., description="Model mode")
has_context: str = Field(default="true", description="Whether has context")
model_name: str = Field(..., description="Model name")
console_ns.schema_model(
AdvancedPromptTemplateQuery.__name__,
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource):
@api.doc("get_advanced_prompt_templates")
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
@api.expect(
api.parser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@api.response(
@console_ns.doc("get_advanced_prompt_templates")
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
@console_ns.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
@api.response(400, "Invalid request parameters")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("app_mode", type=str, required=True, location="args")
parser.add_argument("model_mode", type=str, required=True, location="args")
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
parser.add_argument("model_name", type=str, required=True, location="args")
args = parser.parse_args()
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AdvancedPromptTemplateService.get_prompt(args)
return AdvancedPromptTemplateService.get_prompt(args.model_dump())

View File

@@ -1,6 +1,8 @@
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.helper import uuid_value
@@ -8,29 +10,40 @@ from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AgentLogQuery(BaseModel):
message_id: str = Field(..., description="Message UUID")
conversation_id: str = Field(..., description="Conversation UUID")
@field_validator("message_id", "conversation_id")
@classmethod
def validate_uuid(cls, value: str) -> str:
return uuid_value(value)
console_ns.schema_model(
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
class AgentLogApi(Resource):
@api.doc("get_agent_logs")
@api.doc(description="Get agent execution logs for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
@console_ns.doc("get_agent_logs")
@console_ns.doc(description="Get agent execution logs for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
@console_ns.response(
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
)
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
@api.response(400, "Invalid request parameters")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)

View File

@@ -1,59 +1,114 @@
from typing import Literal
from typing import Any, Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
annotation_import_concurrency_limit,
annotation_import_rate_limit,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
build_annotation_model,
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AnnotationReplyPayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold for annotation matching")
embedding_provider_name: str = Field(..., description="Embedding provider name")
embedding_model_name: str = Field(..., description="Embedding model name")
class AnnotationSettingUpdatePayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Page size")
keyword: str = Field(default="", description="Search keyword")
class CreateAnnotationPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
question: str | None = Field(default=None, description="Question text")
answer: str | None = Field(default=None, description="Answer text")
content: str | None = Field(default=None, description="Content text")
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class UpdateAnnotationPayload(BaseModel):
question: str | None = None
answer: str | None = None
content: str | None = None
annotation_reply: dict[str, Any] | None = None
class AnnotationReplyStatusQuery(BaseModel):
action: Literal["enable", "disable"]
class AnnotationFilePayload(BaseModel):
message_id: str = Field(..., description="Message ID")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
def reg(model: type[BaseModel]) -> None:
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AnnotationReplyPayload)
reg(AnnotationSettingUpdatePayload)
reg(AnnotationListQuery)
reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
@api.doc("annotation_reply_action")
@api.doc(description="Enable or disable annotation reply for an app")
@api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@api.expect(
api.model(
"AnnotationReplyActionRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
},
)
)
@api.response(200, "Action completed successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("annotation_reply_action")
@console_ns.doc(description="Enable or disable annotation reply for an app")
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
@console_ns.response(200, "Action completed successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
args = AnnotationReplyPayload.model_validate(console_ns.payload)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id)
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -61,18 +116,16 @@ class AnnotationReplyActionApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotation-setting")
class AppAnnotationSettingDetailApi(Resource):
@api.doc("get_annotation_setting")
@api.doc(description="Get annotation settings for an app")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Annotation settings retrieved successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("get_annotation_setting")
@console_ns.doc(description="Get annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Annotation settings retrieved successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200
@@ -80,54 +133,39 @@ class AppAnnotationSettingDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
class AppAnnotationSettingUpdateApi(Resource):
@api.doc("update_annotation_setting")
@api.doc(description="Update annotation settings for an app")
@api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@api.expect(
api.model(
"AnnotationSettingUpdateRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
"embedding_model_name": fields.String(required=True, description="Embedding model"),
},
)
)
@api.response(200, "Settings updated successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("update_annotation_setting")
@console_ns.doc(description="Update annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
@console_ns.response(200, "Settings updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id, annotation_setting_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args()
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>")
class AnnotationReplyActionStatusApi(Resource):
@api.doc("get_annotation_reply_action_status")
@api.doc(description="Get status of annotation reply action job")
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
@api.response(200, "Job status retrieved successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("get_annotation_reply_action_status")
@console_ns.doc(description="Get status of annotation reply action job")
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
@console_ns.response(200, "Job status retrieved successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
@@ -145,27 +183,21 @@ class AnnotationReplyActionStatusApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations")
class AnnotationApi(Resource):
@api.doc("list_annotations")
@api.doc(description="Get annotations for an app with pagination")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
)
@api.response(200, "Annotations retrieved successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("list_annotations")
@console_ns.doc(description="Get annotations for an app with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
@console_ns.response(200, "Annotations retrieved successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
keyword = args.keyword
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
@@ -178,45 +210,30 @@ class AnnotationApi(Resource):
}
return response, 200
@api.doc("create_annotation")
@api.doc(description="Create a new annotation for an app")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"CreateAnnotationRequest",
{
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
)
@api.response(201, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return annotation
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# Use request.args.getlist to get annotation_ids array directly
@@ -241,57 +258,61 @@ class AnnotationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource):
@api.doc("export_annotations")
@api.doc(description="Export all annotations for an app")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
@api.response(403, "Insufficient permissions")
@console_ns.doc("export_annotations")
@console_ns.doc(description="Export all annotations for an app with CSV injection protection")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
"Annotations exported successfully",
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200
response_data = {"data": marshal(annotation_list, annotation_fields)}
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
response.headers["Content-Type"] = "application/json; charset=utf-8"
response.headers["X-Content-Type-Options"] = "nosniff"
return response
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@api.doc("update_delete_annotation")
@api.doc(description="Update or delete an annotation")
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@api.response(200, "Annotation updated successfully", annotation_fields)
@api.response(204, "Annotation deleted successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("update_delete_annotation")
@console_ns.doc(description="Update or delete an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return annotation
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
@@ -300,21 +321,26 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource):
@api.doc("batch_import_annotations")
@api.doc(description="Batch import annotations from CSV file")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Batch import started successfully")
@api.response(403, "Insufficient permissions")
@api.response(400, "No file uploaded or too many files")
@console_ns.doc("batch_import_annotations")
@console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Batch import started successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "No file uploaded or too many files")
@console_ns.response(413, "File too large")
@console_ns.response(429, "Too many requests or concurrent imports")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@annotation_import_rate_limit
@annotation_import_concurrency_limit
@edit_permission_required
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
from configs import dify_config
app_id = str(app_id)
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@@ -324,27 +350,43 @@ class AnnotationBatchImportApi(Resource):
# get file from request
file = request.files["file"]
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
# Check file size before processing
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(
413,
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
f"Please reduce the file size and try again.",
)
if file_size == 0:
raise ValueError("The uploaded file is empty")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
class AnnotationBatchImportStatusApi(Resource):
@api.doc("get_batch_import_status")
@api.doc(description="Get status of batch import job")
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
@api.response(200, "Job status retrieved successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("get_batch_import_status")
@console_ns.doc(description="Get status of batch import job")
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
@console_ns.response(200, "Job status retrieved successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
cache_result = redis_client.get(indexing_cache_key)
@@ -361,25 +403,32 @@ class AnnotationBatchImportStatusApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
class AnnotationHitHistoryListApi(Resource):
@api.doc("list_annotation_hit_histories")
@api.doc(description="Get hit histories for an annotation")
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@api.expect(
api.parser()
@console_ns.doc("list_annotation_hit_histories")
@console_ns.doc(description="Get hit histories for an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
)
@api.response(
200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
@console_ns.response(
200,
"Hit histories retrieved successfully",
console_ns.model(
"AnnotationHitHistoryList",
{
"data": fields.List(
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
)
},
),
)
@api.response(403, "Insufficient permissions")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id)

View File

@@ -1,101 +1,280 @@
import uuid
from typing import cast
from typing import Literal
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
from werkzeug.exceptions import BadRequest
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required
from models import Account, App
from fields.app_fields import (
deleted_tool_fields,
model_config_fields,
model_config_partial_fields,
site_fields,
tag_fields,
)
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not items:
return None
try:
return [str(uuid.UUID(item)) for item in items]
except ValueError as exc:
raise ValueError("Invalid UUID format in tag_ids.") from exc
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
class AppNamePayload(BaseModel):
name: str = Field(..., min_length=1, description="Name to check")
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
icon_background: str | None = Field(default=None, description="Icon background color")
class AppSiteStatusPayload(BaseModel):
enable_site: bool = Field(..., description="Enable or disable site")
class AppApiStatusPayload(BaseModel):
enable_api: bool = Field(..., description="Enable or disable API")
class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str | None = Field(default=None, description="Tracing provider")
@field_validator("tracing_provider")
@classmethod
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
if info.data.get("enabled") and not value:
raise ValueError("tracing_provider is required when enabled is True")
return value
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AppListQuery)
reg(CreateAppPayload)
reg(UpdateAppPayload)
reg(CopyAppPayload)
reg(AppExportQuery)
reg(AppNamePayload)
reg(AppIconPayload)
reg(AppSiteStatusPayload)
reg(AppApiStatusPayload)
reg(AppTracePayload)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base models first
tag_model = console_ns.model("Tag", tag_fields)
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
model_config_model = console_ns.model("ModelConfig", model_config_fields)
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
site_model = console_ns.model("Site", site_fields)
app_partial_model = console_ns.model(
"AppPartial",
{
"id": fields.String,
"name": fields.String,
"max_active_requests": fields.Raw(),
"description": fields.String(attribute="desc_or_prompt"),
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"tags": fields.List(fields.Nested(tag_model)),
"access_mode": fields.String,
"create_user_name": fields.String,
"author_name": fields.String,
"has_draft_trigger": fields.Boolean,
},
)
app_detail_model = console_ns.model(
"AppDetail",
{
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"tracing": fields.Raw,
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_model)),
},
)
app_detail_with_site_model = console_ns.model(
"AppDetailWithSite",
{
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"max_active_requests": fields.Integer,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_model)),
"site": fields.Nested(site_model),
},
)
app_pagination_model = console_ns.model(
"AppPagination",
{
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
},
)
@console_ns.route("/apps")
class AppListApi(Resource):
@api.doc("list_apps")
@api.doc(description="Get list of applications with pagination and filtering")
@api.expect(
api.parser()
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
.add_argument(
"mode",
type=str,
location="args",
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
default="all",
help="App mode filter",
)
.add_argument("name", type=str, location="args", help="Filter by app name")
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
)
@api.response(200, "Success", app_pagination_fields)
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__])
@console_ns.response(200, "Success", app_pagination_model)
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(",")]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args()
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args_dict = args.model_dump()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@@ -109,71 +288,75 @@ class AppListApi(Resource):
if str(app.id) in res:
app.access_mode = res[str(app.id)].access_mode
return marshal(app_pagination, app_pagination_fields), 200
workflow_capable_app_ids = [
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
]
draft_trigger_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
db.session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
)
)
.scalars()
.all()
)
trigger_node_types = {
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
try:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
continue
@api.doc("create_app")
@api.doc(description="Create a new application")
@api.expect(
api.model(
"CreateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(201, "App created successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@api.response(400, "Invalid request parameters")
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
return marshal(app_pagination, app_pagination_model), 200
@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
@console_ns.response(201, "App created successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields)
@marshal_with(app_detail_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
app_service = AppService()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
return app, 201
@console_ns.route("/apps/<uuid:app_id>")
class AppApi(Resource):
@api.doc("get_app_detail")
@api.doc(description="Get application details")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Success", app_detail_fields_with_site)
@console_ns.doc("get_app_detail")
@console_ns.doc(description="Get application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Success", app_detail_with_site_model)
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@get_app_model
@marshal_with(app_detail_fields_with_site)
@marshal_with(app_detail_with_site_model)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
@@ -186,79 +369,50 @@ class AppApi(Resource):
return app_model
@api.doc("update_app")
@api.doc(description="Update application details")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"UpdateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
"max_active_requests": fields.Integer(description="Maximum active requests"),
},
)
)
@api.response(200, "App updated successfully", app_detail_fields_with_site)
@api.response(403, "Insufficient permissions")
@api.response(400, "Invalid request parameters")
@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields_with_site)
@edit_permission_required
@marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
args = parser.parse_args()
args = UpdateAppPayload.model_validate(console_ns.payload)
app_service = AppService()
# Construct ArgsDict from parsed arguments
from services.app_service import AppService as AppServiceType
args_dict: AppServiceType.ArgsDict = {
"name": args["name"],
"description": args.get("description", ""),
"icon_type": args.get("icon_type", ""),
"icon": args.get("icon", ""),
"icon_background": args.get("icon_background", ""),
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
"max_active_requests": args.get("max_active_requests", 0),
args_dict: AppService.ArgsDict = {
"name": args.name,
"description": args.description or "",
"icon_type": args.icon_type or "",
"icon": args.icon or "",
"icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
"max_active_requests": args.max_active_requests or 0,
}
app_model = app_service.update_app(app_model, args_dict)
return app_model
@api.doc("delete_app")
@api.doc(description="Delete application")
@api.doc(params={"app_id": "Application ID"})
@api.response(204, "App deleted successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("delete_app")
@console_ns.doc(description="Delete application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(204, "App deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_model):
"""Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService()
app_service.delete_app(app_model)
@@ -267,55 +421,37 @@ class AppApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/copy")
class AppCopyApi(Resource):
@api.doc("copy_app")
@api.doc(description="Create a copy of an existing application")
@api.doc(params={"app_id": "Application ID to copy"})
@api.expect(
api.model(
"CopyAppRequest",
{
"name": fields.String(description="Name for the copied app"),
"description": fields.String(description="Description for the copied app"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(201, "App copied successfully", app_detail_fields_with_site)
@api.response(403, "Insufficient permissions")
@console_ns.doc("copy_app")
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields_with_site)
@edit_permission_required
@marshal_with(app_detail_with_site_model)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
account=current_user,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
)
session.commit()
@@ -327,178 +463,131 @@ class AppCopyApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/export")
class AppExportApi(Resource):
@api.doc("export_app")
@api.doc(description="Export application configuration as DSL")
@api.doc(params={"app_id": "Application ID to export"})
@api.expect(
api.parser()
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
)
@api.response(
@console_ns.doc("export_app")
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
@console_ns.response(
200,
"App exported successfully",
api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
)
@api.response(403, "Insufficient permissions")
@console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_model):
"""Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args()
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
)
}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@api.doc("check_app_name")
@api.doc(description="Check if app name is available")
@api.doc(params={"app_id": "Application ID"})
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
@api.response(200, "Name availability checked")
@console_ns.doc("check_app_name")
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
@console_ns.response(200, "Name availability checked")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = AppNamePayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_name(app_model, args["name"])
app_model = app_service.update_app_name(app_model, args.name)
return app_model
@console_ns.route("/apps/<uuid:app_id>/icon")
class AppIconApi(Resource):
@api.doc("update_app_icon")
@api.doc(description="Update application icon")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppIconRequest",
{
"icon": fields.String(required=True, description="Icon data"),
"icon_type": fields.String(description="Icon type"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(200, "Icon updated successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("update_app_icon")
@console_ns.doc(description="Update application icon")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
@console_ns.response(200, "Icon updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
return app_model
@console_ns.route("/apps/<uuid:app_id>/site-enable")
class AppSiteStatus(Resource):
@api.doc("update_app_site_status")
@api.doc(description="Enable or disable app site")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
)
)
@api.response(200, "Site status updated successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@console_ns.doc("update_app_site_status")
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
@console_ns.response(200, "Site status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
args = AppSiteStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
app_model = app_service.update_app_site_status(app_model, args.enable_site)
return app_model
@console_ns.route("/apps/<uuid:app_id>/api-enable")
class AppApiStatus(Resource):
@api.doc("update_app_api_status")
@api.doc(description="Enable or disable app API")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
)
)
@api.response(200, "API status updated successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@console_ns.doc("update_app_api_status")
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
@console_ns.response(200, "API status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@marshal_with(app_detail_model)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()
args = AppApiStatusPayload.model_validate(console_ns.payload)
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
app_model = app_service.update_app_api_status(app_model, args.enable_api)
return app_model
@console_ns.route("/apps/<uuid:app_id>/trace")
class AppTraceApi(Resource):
@api.doc("get_app_trace")
@api.doc(description="Get app tracing configuration")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Trace configuration retrieved successfully")
@console_ns.doc("get_app_trace")
@console_ns.doc(description="Get app tracing configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Trace configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@@ -508,36 +597,24 @@ class AppTraceApi(Resource):
return app_trace_config
@api.doc("update_app_trace")
@api.doc(description="Update app tracing configuration")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppTraceRequest",
{
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
"tracing_provider": fields.String(required=True, description="Tracing provider"),
},
)
)
@api.response(200, "Trace configuration updated successfully")
@api.response(403, "Insufficient permissions")
@console_ns.doc("update_app_trace")
@console_ns.doc(description="Update app tracing configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
@console_ns.response(200, "Trace configuration updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id):
# add app trace
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json")
args = parser.parse_args()
args = AppTracePayload.model_validate(console_ns.payload)
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
enabled=args["enabled"],
tracing_provider=args["tracing_provider"],
enabled=args.enabled,
tracing_provider=args.tracing_provider,
)
return {"result": "success"}

View File

@@ -1,65 +1,91 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from fields.app_fields import (
app_import_check_dependencies_fields,
app_import_fields,
leaked_dependency_fields,
)
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from .. import console_ns
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
app_import_model = console_ns.model("AppImport", app_import_fields)
# For nested models, need to replace nested dict with registered model
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
@marshal_with(app_import_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("app_id", type=str, location="json")
args = parser.parse_args()
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
account = current_user
result = import_service.import_app(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
import_mode=args.mode,
yaml_content=args.yaml_content,
yaml_url=args.yaml_url,
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
app_id=args.app_id,
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
@@ -67,47 +93,47 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:import_id>/confirm")
class AppImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
@marshal_with(app_import_model)
@edit_permission_required
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
class AppImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
@marshal_with(app_import_check_dependencies_model)
@edit_permission_required
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@@ -1,11 +1,12 @@
import logging
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -32,20 +33,41 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
text: str = Field(..., description="Text to convert")
voice: str | None = Field(default=None, description="Voice name")
streaming: bool | None = Field(default=None, description="Whether to stream audio")
class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
class ChatMessageAudioApi(Resource):
@api.doc("chat_message_audio_transcript")
@api.doc(description="Transcript audio to text for chat messages")
@api.doc(params={"app_id": "App ID"})
@api.response(
@console_ns.doc("chat_message_audio_transcript")
@console_ns.doc(description="Transcript audio to text for chat messages")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.response(
200,
"Audio transcription successful",
api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
)
@api.response(400, "Bad request - No audio uploaded or unsupported type")
@api.response(413, "Audio file too large")
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
@console_ns.response(413, "Audio file too large")
@setup_required
@login_required
@account_initialization_required
@@ -89,41 +111,26 @@ class ChatMessageAudioApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/text-to-audio")
class ChatMessageTextApi(Resource):
@api.doc("chat_message_text_to_speech")
@api.doc(description="Convert text to speech for chat messages")
@api.doc(params={"app_id": "App ID"})
@api.expect(
api.model(
"TextToSpeechRequest",
{
"message_id": fields.String(description="Message ID"),
"text": fields.String(required=True, description="Text to convert to speech"),
"voice": fields.String(description="Voice to use for TTS"),
"streaming": fields.Boolean(description="Whether to stream the audio"),
},
)
)
@api.response(200, "Text to speech conversion successful")
@api.response(400, "Bad request - Invalid parameters")
@console_ns.doc("chat_message_text_to_speech")
@console_ns.doc(description="Convert text to speech for chat messages")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
@console_ns.response(200, "Text to speech conversion successful")
@console_ns.response(400, "Bad request - Invalid parameters")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def post(self, app_model: App):
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
payload = TextToSpeechPayload.model_validate(console_ns.payload)
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
app_model=app_model,
text=payload.text,
voice=payload.voice,
message_id=payload.message_id,
is_draft=True,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -154,25 +161,25 @@ class ChatMessageTextApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
class TextModesApi(Resource):
@api.doc("get_text_to_speech_voices")
@api.doc(description="Get available TTS voices for a specific language")
@api.doc(params={"app_id": "App ID"})
@api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code"))
@api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")))
@api.response(400, "Invalid language parameter")
@console_ns.doc("get_text_to_speech_voices")
@console_ns.doc(description="Get available TTS voices for a specific language")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
@console_ns.response(
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
)
@console_ns.response(400, "Invalid language parameter")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
try:
parser = reqparse.RequestParser()
parser.add_argument("language", type=str, required=True, location="args")
args = parser.parse_args()
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args["language"],
language=args.language,
)
return response

View File

@@ -1,11 +1,13 @@
import logging
from typing import Any, Literal
from flask import request
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -15,9 +17,8 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
@@ -32,48 +33,66 @@ from libs.login import current_user, login_required
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
model_config_data: dict[str, Any] = Field(..., alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
class CompletionMessagePayload(BaseMessagePayload):
query: str = Field(default="", description="Query text")
class ChatMessagePayload(BaseMessagePayload):
query: str = Field(..., description="User query")
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
console_ns.schema_model(
CompletionMessagePayload.__name__,
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# define completion message api for user
@console_ns.route("/apps/<uuid:app_id>/completion-messages")
class CompletionMessageApi(Resource):
@api.doc("create_completion_message")
@api.doc(description="Generate completion message for debugging")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"CompletionMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(description="Query text", default=""),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@api.response(200, "Completion generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(404, "App not found")
@console_ns.doc("create_completion_message")
@console_ns.doc(description="Generate completion message for debugging")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
@console_ns.response(200, "Completion generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
streaming = args["response_mode"] != "blocking"
streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False
try:
@@ -108,10 +127,10 @@ class CompletionMessageApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
class CompletionMessageStopApi(Resource):
@api.doc("stop_completion_message")
@api.doc(description="Stop a running completion message generation")
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@api.response(200, "Task stopped successfully")
@console_ns.doc("stop_completion_message")
@console_ns.doc(description="Stop a running completion message generation")
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@console_ns.response(200, "Task stopped successfully")
@setup_required
@login_required
@account_initialization_required
@@ -119,57 +138,36 @@ class CompletionMessageStopApi(Resource):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageApi(Resource):
@api.doc("create_chat_message")
@api.doc(description="Generate chat message for debugging")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"ChatMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(required=True, description="User query"),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"conversation_id": fields.String(description="Conversation ID"),
"parent_message_id": fields.String(description="Parent message ID"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@api.response(200, "Chat message generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(404, "App or conversation not found")
@console_ns.doc("create_chat_message")
@console_ns.doc(description="Generate chat message for debugging")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
@console_ns.response(200, "Chat message generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App or conversation not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
args_model = ChatMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
streaming = args["response_mode"] != "blocking"
streaming = args_model.response_mode != "blocking"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)
@@ -210,10 +208,10 @@ class ChatMessageApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
class ChatMessageStopApi(Resource):
@api.doc("stop_chat_message")
@api.doc(description="Stop a running chat message generation")
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@api.response(200, "Task stopped successfully")
@console_ns.doc("stop_chat_message")
@console_ns.doc(description="Stop a running chat message generation")
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@console_ns.response(200, "Task stopped successfully")
@setup_required
@login_required
@account_initialization_required
@@ -221,6 +219,12 @@ class ChatMessageStopApi(Resource):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200

View File

@@ -1,116 +1,376 @@
from datetime import datetime
from typing import Literal
import pytz # pip install pytz
import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
conversation_detail_fields,
conversation_message_detail_fields,
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString
from libs.login import login_required
from models import Account, Conversation, EndUser, Message, MessageAnnotation
from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseConversationQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword")
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
default="all", description="Annotation status filter"
)
page: int = Field(default=1, ge=1, le=99999, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
class CompletionConversationQuery(BaseConversationQuery):
pass
class ChatConversationQuery(BaseConversationQuery):
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction"
)
console_ns.schema_model(
CompletionConversationQuery.__name__,
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatConversationQuery.__name__,
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model(
"SimpleAccount",
{
"id": fields.String,
"name": fields.String,
"email": fields.String,
},
)
feedback_stat_model = console_ns.model(
"FeedbackStat",
{
"like": fields.Integer,
"dislike": fields.Integer,
},
)
status_count_model = console_ns.model(
"StatusCount",
{
"success": fields.Integer,
"failed": fields.Integer,
"partial_success": fields.Integer,
},
)
message_file_model = console_ns.model(
"MessageFile",
{
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
},
)
agent_thought_model = console_ns.model(
"AgentThought",
{
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
},
)
simple_model_config_model = console_ns.model(
"SimpleModelConfig",
{
"model": fields.Raw(attribute="model_dict"),
"pre_prompt": fields.String,
},
)
model_config_model = console_ns.model(
"ModelConfig",
{
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"model": fields.Raw,
"user_input_form": fields.Raw,
"pre_prompt": fields.String,
"agent_mode": fields.Raw,
},
)
# Models that depend on simple_account_model
feedback_model = console_ns.model(
"Feedback",
{
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_model, allow_null=True),
},
)
annotation_model = console_ns.model(
"Annotation",
{
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
annotation_hit_history_model = console_ns.model(
"AnnotationHitHistory",
{
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
# Simple message detail model
simple_message_detail_model = console_ns.model(
"SimpleMessageDetail",
{
"inputs": FilesContainedField,
"query": fields.String,
"message": MessageTextField,
"answer": fields.String,
},
)
# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
"MessageDetail",
{
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_model)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_model, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
},
)
# Conversation models
conversation_fields_model = console_ns.model(
"Conversation",
{
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String(),
"from_account_id": fields.String,
"from_account_name": fields.String,
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotation": fields.Nested(annotation_model, allow_null=True),
"model_config": fields.Nested(simple_model_config_model),
"user_feedback_stats": fields.Nested(feedback_stat_model),
"admin_feedback_stats": fields.Nested(feedback_stat_model),
"message": fields.Nested(simple_message_detail_model, attribute="first_message"),
},
)
conversation_pagination_model = console_ns.model(
"ConversationPagination",
{
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_fields_model), attribute="items"),
},
)
conversation_message_detail_model = console_ns.model(
"ConversationMessageDetail",
{
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"model_config": fields.Nested(model_config_model),
"message": fields.Nested(message_detail_model, attribute="first_message"),
},
)
conversation_with_summary_model = console_ns.model(
"ConversationWithSummary",
{
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String,
"from_account_id": fields.String,
"from_account_name": fields.String,
"name": fields.String,
"summary": fields.String(attribute="summary_or_query"),
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"model_config": fields.Nested(simple_model_config_model),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_model),
"admin_feedback_stats": fields.Nested(feedback_stat_model),
"status_count": fields.Nested(status_count_model),
},
)
conversation_with_summary_pagination_model = console_ns.model(
"ConversationWithSummaryPagination",
{
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"),
},
)
conversation_detail_model = console_ns.model(
"ConversationDetail",
{
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"introduction": fields.String,
"model_config": fields.Nested(model_config_model),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_model),
"admin_feedback_stats": fields.Nested(feedback_stat_model),
},
)
@console_ns.route("/apps/<uuid:app_id>/completion-conversations")
class CompletionConversationApi(Resource):
@api.doc("list_completion_conversations")
@api.doc(description="Get completion conversations with pagination and filtering")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
)
@api.response(200, "Success", conversation_pagination_fields)
@api.response(403, "Insufficient permissions")
@console_ns.doc("list_completion_conversations")
@console_ns.doc(description="Get completion conversations with pagination and filtering")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
@console_ns.response(200, "Success", conversation_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
@marshal_with(conversation_pagination_model)
@edit_permission_required
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
current_user, _ = current_account_with_tenant()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)
if args["keyword"]:
if args.keyword:
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args['keyword']}%"),
Message.answer.ilike(f"%{args['keyword']}%"),
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
)
)
account = current_user
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
assert account.timezone is not None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
if args["annotation_status"] == "annotated":
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
@@ -119,49 +379,46 @@ class CompletionConversationApi(Resource):
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
class CompletionConversationDetailApi(Resource):
@api.doc("get_completion_conversation")
@api.doc(description="Get completion conversation details with messages")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(200, "Success", conversation_message_detail_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found")
@console_ns.doc("get_completion_conversation")
@console_ns.doc(description="Get completion conversation details with messages")
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@console_ns.response(200, "Success", conversation_message_detail_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
@marshal_with(conversation_message_detail_model)
@edit_permission_required
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@api.doc("delete_completion_conversation")
@api.doc(description="Delete a completion conversation")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(204, "Conversation deleted successfully")
@api.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found")
@console_ns.doc("delete_completion_conversation")
@console_ns.doc(description="Delete a completion conversation")
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@console_ns.response(204, "Conversation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
conversation_id = str(conversation_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -171,63 +428,21 @@ class CompletionConversationDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
class ChatConversationApi(Resource):
@api.doc("list_chat_conversations")
@api.doc(description="Get chat conversations with pagination, filtering and summary")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
.add_argument(
"sort_by",
type=str,
location="args",
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
default="-updated_at",
help="Sort field and direction",
)
)
@api.response(200, "Success", conversation_with_summary_pagination_fields)
@api.response(403, "Insufficient permissions")
@console_ns.doc("list_chat_conversations")
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields)
@marshal_with(conversation_with_summary_pagination_model)
@edit_permission_required
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
args = parser.parse_args()
current_user, _ = current_account_with_tenant()
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
subquery = (
db.session.query(
@@ -239,8 +454,8 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args["keyword"]:
keyword_filter = f"%{args['keyword']}%"
if args.keyword:
keyword_filter = f"%{args.keyword}%"
query = (
query.join(
Message,
@@ -260,58 +475,43 @@ class ChatConversationApi(Resource):
)
account = current_user
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
assert account.timezone is not None
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]:
if start_datetime_utc:
match args.sort_by:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]:
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args.sort_by:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated":
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args["annotation_status"] == "not_annotated":
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]:
match args.sort_by:
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
@@ -323,49 +523,46 @@ class ChatConversationApi(Resource):
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
class ChatConversationDetailApi(Resource):
@api.doc("get_chat_conversation")
@api.doc(description="Get chat conversation details")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(200, "Success", conversation_detail_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found")
@console_ns.doc("get_chat_conversation")
@console_ns.doc(description="Get chat conversation details")
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@console_ns.response(200, "Success", conversation_detail_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields)
@marshal_with(conversation_detail_model)
@edit_permission_required
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@api.doc("delete_chat_conversation")
@api.doc(description="Delete a chat conversation")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(204, "Conversation deleted successfully")
@api.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found")
@console_ns.doc("delete_chat_conversation")
@console_ns.doc(description="Delete a chat conversation")
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@console_ns.response(204, "Conversation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
conversation_id = str(conversation_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -374,6 +571,7 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)

View File

@@ -1,47 +1,68 @@
from flask_restx import Resource, marshal_with, reqparse
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.conversation_variable_fields import paginated_conversation_variable_fields
from fields.conversation_variable_fields import (
conversation_variable_fields,
paginated_conversation_variable_fields,
)
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
# For nested models, need to replace nested dict with registered model
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
paginated_conversation_variable_fields_copy["data"] = fields.List(
fields.Nested(conversation_variable_model), attribute="data"
)
paginated_conversation_variable_model = console_ns.model(
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
)
@console_ns.route("/apps/<uuid:app_id>/conversation-variables")
class ConversationVariablesApi(Resource):
@api.doc("get_conversation_variables")
@api.doc(description="Get conversation variables for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser().add_argument(
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
)
)
@api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
@console_ns.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", type=str, location="args")
args = parser.parse_args()
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
if args["conversation_id"]:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
else:
raise ValueError("conversation_id is required")
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
# NOTE: This is a temporary solution to avoid performance issues.
page = 1

View File

@@ -1,9 +1,10 @@
from collections.abc import Sequence
from typing import Any
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@@ -12,50 +13,80 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(RuleGeneratePayload)
reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
@console_ns.route("/rule-generate")
class RuleGenerateApi(Resource):
@api.doc("generate_rule_config")
@api.doc(description="Generate rule configuration using LLM")
@api.expect(
api.model(
"RuleGenerateRequest",
{
"instruction": fields.String(required=True, description="Rule generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
},
)
)
@api.response(200, "Rule configuration generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@console_ns.doc("generate_rule_config")
@console_ns.doc(description="Generate rule configuration using LLM")
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
@console_ns.response(200, "Rule configuration generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
args = parser.parse_args()
args = RuleGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -71,42 +102,25 @@ class RuleGenerateApi(Resource):
@console_ns.route("/rule-code-generate")
class RuleCodeGenerateApi(Resource):
@api.doc("generate_rule_code")
@api.doc(description="Generate code rules using LLM")
@api.expect(
api.model(
"RuleCodeGenerateRequest",
{
"instruction": fields.String(required=True, description="Code generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
"code_language": fields.String(
default="javascript", description="Programming language for code generation"
),
},
)
)
@api.response(200, "Code rules generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@console_ns.doc("generate_rule_code")
@console_ns.doc(description="Generate code rules using LLM")
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
@console_ns.response(200, "Code rules generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
args = parser.parse_args()
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -122,35 +136,24 @@ class RuleCodeGenerateApi(Resource):
@console_ns.route("/rule-structured-output-generate")
class RuleStructuredOutputGenerateApi(Resource):
@api.doc("generate_structured_output")
@api.doc(description="Generate structured output rules using LLM")
@api.expect(
api.model(
"StructuredOutputGenerateRequest",
{
"instruction": fields.String(required=True, description="Structured output generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
},
)
)
@api.response(200, "Structured output generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@console_ns.doc("generate_structured_output")
@console_ns.doc(description="Generate structured output rules using LLM")
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
@console_ns.response(200, "Structured output generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -166,101 +169,79 @@ class RuleStructuredOutputGenerateApi(Resource):
@console_ns.route("/instruction-generate")
class InstructionGenerateApi(Resource):
@api.doc("generate_instruction")
@api.doc(description="Generate instruction for workflow nodes or general use")
@api.expect(
api.model(
"InstructionGenerateRequest",
{
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
"node_id": fields.String(description="Node ID for workflow context"),
"current": fields.String(description="Current instruction text"),
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
"instruction": fields.String(required=True, description="Instruction for generation"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@api.response(200, "Instruction generated successfully")
@api.response(400, "Invalid request parameters or flow/workflow not found")
@api.response(402, "Provider quota exceeded")
@console_ns.doc("generate_instruction")
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
@console_ns.response(200, "Instruction generated successfully")
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
parser.add_argument("node_id", type=str, required=False, default="", location="json")
parser.add_argument("current", type=str, required=False, default="", location="json")
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
else (JavascriptCodeProvider.get_default_code())
if args["language"] == "javascript"
else ""
args = InstructionGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
if not app:
return {"error": f"app {args['flow_id']} not found"}, 400
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"error": f"workflow {args['flow_id']} not found"}, 400
return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args["node_id"]]
node = [node for node in nodes if node["id"] == args.node_id]
if len(node) == 0:
return {"error": f"node {args['node_id']} not found"}, 400
return {"error": f"node {args.node_id} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
if args.node_id == "" and args.current != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
tenant_id=current_tenant_id,
flow_id=args.flow_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
)
if args["node_id"] != "" and args["current"] != "": # For workflow node
if args.node_id != "" and args.current != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
tenant_id=current_tenant_id,
flow_id=args.flow_id,
node_id=args.node_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
workflow_service=WorkflowService(),
)
return {"error": "incompatible parameters"}, 400
@@ -276,27 +257,17 @@ class InstructionGenerateApi(Resource):
@console_ns.route("/instruction-generate/template")
class InstructionGenerationTemplateApi(Resource):
@api.doc("get_instruction_template")
@api.doc(description="Get instruction generation template")
@api.expect(
api.model(
"InstructionTemplateRequest",
{
"instruction": fields.String(required=True, description="Template instruction"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@api.response(200, "Template retrieved successfully")
@api.response(400, "Invalid request parameters")
@console_ns.doc("get_instruction_template")
@console_ns.doc(description="Get instruction generation template")
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
@console_ns.response(200, "Template retrieved successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()
match args["type"]:
args = InstructionTemplatePayload.model_validate(console_ns.payload)
match args.type:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
@@ -306,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args['type']}")
raise ValueError(f"Invalid type: {args.type}")

View File

@@ -1,119 +1,113 @@
import json
from enum import StrEnum
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
class AppMCPServerStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource):
@api.doc("get_app_mcp_server")
@api.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@setup_required
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
@login_required
@account_initialization_required
@setup_required
@get_app_model
@marshal_with(app_server_fields)
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server
@api.doc("create_app_mcp_server")
@api.doc(description="Create MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MCPServerCreateRequest",
{
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
},
)
)
@api.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@marshal_with(app_server_fields)
@login_required
@setup_required
@marshal_with(app_server_model)
@edit_permission_required
def post(self, app_model):
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
description = args.get("description")
description = payload.description
if not description:
description = app_model.description or ""
server = AppMCPServer(
name=app_model.name,
description=description,
parameters=json.dumps(args["parameters"], ensure_ascii=False),
parameters=json.dumps(payload.parameters, ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
server_code=AppMCPServer.generate_server_code(16),
)
db.session.add(server)
db.session.commit()
return server
@api.doc("update_app_mcp_server")
@api.doc(description="Update MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MCPServerUpdateRequest",
{
"id": fields.String(required=True, description="Server ID"),
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
"status": fields.String(description="Server status"),
},
)
)
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@marshal_with(app_server_fields)
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
if not server:
raise NotFound()
description = args.get("description")
description = payload.description
if description is None:
pass
elif not description:
@@ -121,34 +115,34 @@ class AppMCPServerController(Resource):
else:
server.description = description
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
if args["status"]:
if args["status"] not in [status.value for status in AppMCPServerStatus]:
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
server.status = args["status"]
server.status = payload.status
db.session.commit()
return server
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
class AppMCPServerRefreshController(Resource):
@api.doc("refresh_app_mcp_server")
@api.doc(description="Refresh MCP server configuration and regenerate server code")
@api.doc(params={"server_id": "Server ID"})
@api.response(200, "MCP server refreshed successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Server not found")
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
@marshal_with(app_server_model)
@edit_permission_required
def get(self, server_id):
if not current_user.is_editor:
raise NotFound()
_, current_tenant_id = current_account_with_tenant()
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
)
if not server:

View File

@@ -1,11 +1,13 @@
import logging
from typing import Literal
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@@ -16,74 +18,234 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("first_id", mode="before")
@classmethod
def empty_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
class FeedbackExportQuery(BaseModel):
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
@field_validator("has_comment", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool | None:
if isinstance(value, bool) or value is None:
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model(
"SimpleAccount",
{
"id": fields.String,
"name": fields.String,
"email": fields.String,
},
)
message_file_model = console_ns.model(
"MessageFile",
{
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
},
)
agent_thought_model = console_ns.model(
"AgentThought",
{
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
},
)
# Models that depend on simple_account_model
feedback_model = console_ns.model(
"Feedback",
{
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_model, allow_null=True),
},
)
annotation_model = console_ns.model(
"Annotation",
{
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
annotation_hit_history_model = console_ns.model(
"AnnotationHitHistory",
{
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
"MessageDetail",
{
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_model)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_model, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
},
)
# Message infinite scroll pagination model
message_infinite_scroll_pagination_model = console_ns.model(
"MessageInfiniteScrollPagination",
{
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_model)),
},
)
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_fields)),
}
@api.doc("list_chat_messages")
@api.doc(description="Get chat messages for a conversation with pagination")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found")
@setup_required
@console_ns.doc("list_chat_messages")
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found")
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@marshal_with(message_infinite_scroll_pagination_fields)
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
conversation = (
db.session.query(Conversation)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args["first_id"]:
if args.first_id:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
)
@@ -98,7 +260,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.limit(args.limit)
.all()
)
else:
@@ -106,12 +268,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.limit(args.limit)
.all()
)
# Initialize has_more based on whether we have a full page
if len(history_messages) == args["limit"]:
if len(history_messages) == args.limit:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
@@ -129,40 +291,28 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
class MessageFeedbackApi(Resource):
@api.doc("create_message_feedback")
@api.doc(description="Create or update message feedback (like/dislike)")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageFeedbackRequest",
{
"message_id": fields.String(required=True, description="Message ID"),
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
},
)
)
@api.response(200, "Feedback updated successfully")
@api.response(404, "Message not found")
@api.response(403, "Insufficient permissions")
@console_ns.doc("create_message_feedback")
@console_ns.doc(description="Create or update message feedback (like/dislike)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.response(200, "Feedback updated successfully")
@console_ns.response(404, "Message not found")
@console_ns.response(403, "Insufficient permissions")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def post(self, app_model):
if current_user is None:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args()
args = MessageFeedbackPayload.model_validate(console_ns.payload)
message_id = str(args["message_id"])
message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@@ -171,18 +321,23 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback
if not args["rating"] and feedback:
if not args.rating and feedback:
db.session.delete(feedback)
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
elif args.rating and feedback:
feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
rating_value = args.rating
if rating_value is None:
raise ValueError("rating is required to create feedback")
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args["rating"],
rating=rating_value,
content=args.content,
from_source="admin",
from_account_id=current_user.id,
)
@@ -193,56 +348,15 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/annotations")
class MessageAnnotationApi(Resource):
@api.doc("create_message_annotation")
@api.doc(description="Create message annotation")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageAnnotationRequest",
{
"message_id": fields.String(description="Message ID"),
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply"),
},
)
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@get_app_model
@marshal_with(annotation_fields)
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
return annotation
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource):
@api.doc("get_annotation_count")
@api.doc(description="Get count of message annotations for the app")
@api.doc(params={"app_id": "Application ID"})
@api.response(
@console_ns.doc("get_annotation_count")
@console_ns.doc(description="Get count of message annotations for the app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
"Annotation count retrieved successfully",
api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
)
@get_app_model
@setup_required
@@ -256,20 +370,23 @@ class MessageAnnotationCountApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(Resource):
@api.doc("get_message_suggested_questions")
@api.doc(description="Get suggested questions for a message")
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(
@console_ns.doc("get_message_suggested_questions")
@console_ns.doc(description="Get suggested questions for a message")
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@console_ns.response(
200,
"Suggested questions retrieved successfully",
api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}),
console_ns.model(
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
)
@api.response(404, "Message or conversation not found")
@console_ns.response(404, "Message or conversation not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id = str(message_id)
try:
@@ -297,19 +414,59 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
class MessageApi(Resource):
@api.doc("get_message")
@api.doc(description="Get message details by ID")
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found")
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Import the service function
from services.feedback_service import FeedbackService
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
from_source=args.from_source,
rating=args.rating,
has_comment=args.has_comment,
start_date=args.start_date,
end_date=args.end_date,
format_type=args.format,
)
return export_data
except ValueError as e:
logger.exception("Parameter validation error in feedback export")
return {"error": f"Parameter validation error: {str(e)}"}, 400
except Exception as e:
logger.exception("Error exporting feedback data")
raise InternalServerError(str(e))
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
class MessageApi(Resource):
@console_ns.doc("get_message")
@console_ns.doc(description="Get message details by ID")
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
@console_ns.response(404, "Message not found")
@get_app_model
@marshal_with(message_detail_fields)
def get(self, app_model, message_id):
@setup_required
@login_required
@account_initialization_required
@marshal_with(message_detail_model)
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()

View File

@@ -2,31 +2,29 @@ import json
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
from models.account import Account
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource):
@api.doc("update_app_model_config")
@api.doc(description="Update application model configuration")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
@console_ns.doc("update_app_model_config")
@console_ns.doc(description="Update application model configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"ModelConfigRequest",
{
"provider": fields.String(description="Model provider"),
@@ -44,25 +42,20 @@ class ModelConfigResource(Resource):
},
)
)
@api.response(200, "Model configuration updated successfully")
@api.response(400, "Invalid configuration")
@api.response(404, "App not found")
@console_ns.response(200, "Model configuration updated successfully")
@console_ns.response(400, "Invalid configuration")
@console_ns.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
current_user, current_tenant_id = current_account_with_tenant()
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode),
)
@@ -90,16 +83,16 @@ class ModelConfigResource(Resource):
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
@@ -124,7 +117,7 @@ class ModelConfigResource(Resource):
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
@@ -133,7 +126,7 @@ class ModelConfigResource(Resource):
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
@@ -141,7 +134,7 @@ class ModelConfigResource(Resource):
continue
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
@@ -172,6 +165,8 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_at = naive_utc_now()
db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)

View File

@@ -1,12 +1,36 @@
from flask_restx import Resource, fields, reqparse
from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.ops_service import OpsService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TraceProviderQuery(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
class TraceConfigPayload(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
console_ns.schema_model(
TraceProviderQuery.__name__,
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource):
@@ -14,63 +38,46 @@ class TraceAppConfigApi(Resource):
Manage trace app configurations
"""
@api.doc("get_trace_app_config")
@api.doc(description="Get tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@api.response(
@console_ns.doc("get_trace_app_config")
@console_ns.doc(description="Get tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
)
@api.response(400, "Invalid request parameters")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not trace_config:
return {"has_not_configured": True}
return trace_config
except Exception as e:
raise BadRequest(str(e))
@api.doc("create_trace_app_config")
@api.doc(description="Create a new tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"TraceConfigCreateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
},
)
)
@api.response(
@console_ns.doc("create_trace_app_config")
@console_ns.doc(description="Create a new tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
)
@api.response(400, "Invalid request parameters or configuration already exists")
@console_ns.response(400, "Invalid request parameters or configuration already exists")
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
"""Create a new trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
args = parser.parse_args()
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigIsExist()
@@ -80,33 +87,22 @@ class TraceAppConfigApi(Resource):
except Exception as e:
raise BadRequest(str(e))
@api.doc("update_trace_app_config")
@api.doc(description="Update an existing tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"TraceConfigUpdateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
},
)
)
@api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@api.response(400, "Invalid request parameters or configuration not found")
@console_ns.doc("update_trace_app_config")
@console_ns.doc(description="Update an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@login_required
@account_initialization_required
def patch(self, app_id):
"""Update an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_config", type=dict, required=True, location="json")
args = parser.parse_args()
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigNotExist()
@@ -114,27 +110,21 @@ class TraceAppConfigApi(Resource):
except Exception as e:
raise BadRequest(str(e))
@api.doc("delete_trace_app_config")
@api.doc(description="Delete an existing tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@api.response(204, "Tracing configuration deleted successfully")
@api.response(400, "Invalid request parameters or configuration not found")
@console_ns.doc("delete_trace_app_config")
@console_ns.doc(description="Delete an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.response(204, "Tracing configuration deleted successfully")
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id):
"""Delete an existing trace app configuration"""
parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204

View File

@@ -1,86 +1,80 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
is_admin_or_owner_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Account, Site
from libs.login import current_account_with_tenant, login_required
from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def parse_app_site_args():
parser = reqparse.RequestParser()
parser.add_argument("title", type=str, required=False, location="json")
parser.add_argument("icon_type", type=str, required=False, location="json")
parser.add_argument("icon", type=str, required=False, location="json")
parser.add_argument("icon_background", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("default_language", type=supported_language, required=False, location="json")
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
parser.add_argument("customize_domain", type=str, required=False, location="json")
parser.add_argument("copyright", type=str, required=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
parser.add_argument(
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
)
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args()
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
icon_type: str | None = Field(default=None)
icon: str | None = Field(default=None)
icon_background: str | None = Field(default=None)
description: str | None = Field(default=None)
default_language: str | None = Field(default=None)
chat_color_theme: str | None = Field(default=None)
chat_color_theme_inverted: bool | None = Field(default=None)
customize_domain: str | None = Field(default=None)
copyright: str | None = Field(default=None)
privacy_policy: str | None = Field(default=None)
custom_disclaimer: str | None = Field(default=None)
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
prompt_public: bool | None = Field(default=None)
show_workflow_steps: bool | None = Field(default=None)
use_icon_as_answer_icon: bool | None = Field(default=None)
@field_validator("default_language")
@classmethod
def validate_language(cls, value: str | None) -> str | None:
if value is None:
return value
return supported_language(value)
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
@console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource):
@api.doc("update_app_site")
@api.doc(description="Update application site configuration")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"AppSiteRequest",
{
"title": fields.String(description="Site title"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"description": fields.String(description="Site description"),
"default_language": fields.String(description="Default language"),
"chat_color_theme": fields.String(description="Chat color theme"),
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
"customize_domain": fields.String(description="Custom domain"),
"copyright": fields.String(description="Copyright text"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"customize_token_strategy": fields.String(
enum=["must", "allow", "not_allow"], description="Token strategy"
),
"prompt_public": fields.Boolean(description="Make prompt public"),
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
},
)
)
@api.response(200, "Site configuration updated successfully", app_site_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "App not found")
@console_ns.doc("update_app_site")
@console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_fields)
@marshal_with(app_site_model)
def post(self, app_model):
args = parse_app_site_args()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
raise Forbidden()
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound
@@ -103,12 +97,10 @@ class AppSite(Resource):
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
value = getattr(args, attr_name)
if value is not None:
setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()
@@ -118,30 +110,26 @@ class AppSite(Resource):
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
class AppSiteAccessTokenReset(Resource):
@api.doc("reset_app_site_access_token")
@api.doc(description="Reset access token for application site")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Access token reset successfully", app_site_fields)
@api.response(403, "Insufficient permissions (admin/owner required)")
@api.response(404, "App or site not found")
@console_ns.doc("reset_app_site_access_token")
@console_ns.doc(description="Reset access token for application site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Access token reset successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
@console_ns.response(404, "App or site not found")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_fields)
@marshal_with(app_site_model)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound
site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()

View File

@@ -1,33 +1,48 @@
from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from flask import abort, jsonify, request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from models import AppMode, Message
from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from models import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def empty_string_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
StatisticTimeRangeQuery.__name__,
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource):
@api.doc("get_daily_message_statistics")
@api.doc(description="Get daily message statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
@console_ns.doc("get_daily_message_statistics")
@console_ns.doc(description="Get daily message statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Daily message statistics retrieved successfully",
fields.List(fields.Raw(description="Daily message count data")),
@@ -37,43 +52,32 @@ class DailyMessageStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(*) AS message_count
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -91,15 +95,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
@api.doc("get_daily_conversation_statistics")
@api.doc(description="Get daily conversation statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
@console_ns.doc("get_daily_conversation_statistics")
@console_ns.doc(description="Get daily conversation statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Daily conversation statistics retrieved successfully",
fields.List(fields.Raw(description="Daily conversation count data")),
@@ -109,63 +109,53 @@ class DailyConversationStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
stmt = (
sa.select(
sa.func.date(
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
).label("date"),
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
stmt = stmt.where(Message.created_at >= start_datetime_utc)
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
stmt = stmt.where(Message.created_at < end_datetime_utc)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
stmt = stmt.group_by("date").order_by("date")
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(stmt, {"tz": account.timezone})
for row in rs:
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
class DailyTerminalsStatistic(Resource):
@api.doc("get_daily_terminals_statistics")
@api.doc(description="Get daily terminal/end-user statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
@console_ns.doc("get_daily_terminals_statistics")
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Daily terminal statistics retrieved successfully",
fields.List(fields.Raw(description="Daily terminal count data")),
@@ -175,43 +165,32 @@ class DailyTerminalsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -229,15 +208,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
class DailyTokenCostStatistic(Resource):
@api.doc("get_daily_token_cost_statistics")
@api.doc(description="Get daily token cost statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
@console_ns.doc("get_daily_token_cost_statistics")
@console_ns.doc(description="Get daily token cost statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Daily token cost statistics retrieved successfully",
fields.List(fields.Raw(description="Daily token cost data")),
@@ -247,15 +222,13 @@ class DailyTokenCostStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
SUM(total_price) AS total_price
FROM
@@ -263,28 +236,19 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -304,15 +268,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
class AverageSessionInteractionStatistic(Resource):
@api.doc("get_average_session_interaction_statistics")
@api.doc(description="Get average session interaction statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
@console_ns.doc("get_average_session_interaction_statistics")
@console_ns.doc(description="Get average session interaction statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Average session interaction statistics retrieved successfully",
fields.List(fields.Raw(description="Average session interaction data")),
@@ -322,15 +282,13 @@ class AverageSessionInteractionStatistic(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
AVG(subquery.message_count) AS interactions
FROM
(
@@ -345,28 +303,19 @@ FROM
WHERE
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND c.created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -395,15 +344,11 @@ ORDER BY
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
class UserSatisfactionRateStatistic(Resource):
@api.doc("get_user_satisfaction_rate_statistics")
@api.doc(description="Get user satisfaction rate statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
@console_ns.doc("get_user_satisfaction_rate_statistics")
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"User satisfaction rate statistics retrieved successfully",
fields.List(fields.Raw(description="User satisfaction rate data")),
@@ -413,15 +358,13 @@ class UserSatisfactionRateStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(m.id) AS message_count,
COUNT(mf.id) AS feedback_count
FROM
@@ -432,28 +375,19 @@ LEFT JOIN
WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND m.created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -476,15 +410,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
class AverageResponseTimeStatistic(Resource):
@api.doc("get_average_response_time_statistics")
@api.doc(description="Get average response time statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
@console_ns.doc("get_average_response_time_statistics")
@console_ns.doc(description="Get average response time statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Average response time statistics retrieved successfully",
fields.List(fields.Raw(description="Average response time data")),
@@ -494,43 +424,32 @@ class AverageResponseTimeStatistic(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
AVG(provider_response_latency) AS latency
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@@ -548,15 +467,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
class TokensPerSecondStatistic(Resource):
@api.doc("get_tokens_per_second_statistics")
@api.doc(description="Get tokens per second statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
@console_ns.doc("get_tokens_per_second_statistics")
@console_ns.doc(description="Get tokens per second statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response(
200,
"Tokens per second statistics retrieved successfully",
fields.List(fields.Raw(description="Tokens per second data")),
@@ -566,15 +481,12 @@ class TokensPerSecondStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
@@ -584,28 +496,19 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc

File diff suppressed because it is too large Load Diff

View File

@@ -1,82 +1,85 @@
from datetime import datetime
from dateutil.parser import isoparse
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs.login import login_required
from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
status: WorkflowExecutionStatus | None = Field(
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
)
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
created_by_account: str | None = Field(default=None, description="Filter by account")
detail: bool = Field(default=False, description="Whether to return detailed logs")
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
@field_validator("created_at__before", "created_at__after", mode="before")
@classmethod
def parse_datetime(cls, value: str | None) -> datetime | None:
if value in (None, ""):
return None
return isoparse(value) # type: ignore
@field_validator("detail", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
lowered = value.lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
class WorkflowAppLogApi(Resource):
@api.doc("get_workflow_app_logs")
@api.doc(description="Get workflow application execution logs")
@api.doc(params={"app_id": "Application ID"})
@api.doc(
params={
"keyword": "Search keyword for filtering logs",
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
"created_at__before": "Filter logs created before this timestamp",
"created_at__after": "Filter logs created after this timestamp",
"created_by_end_user_session_id": "Filter by end user session ID",
"created_by_account": "Filter by account",
"page": "Page number (1-99999)",
"limit": "Number of items per page (1-100)",
}
)
@api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields)
@console_ns.doc("get_workflow_app_logs")
@console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_app_log_pagination_fields)
@marshal_with(workflow_app_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
@@ -90,6 +93,7 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
detail=args.detail,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
)

View File

@@ -1,17 +1,19 @@
import logging
from typing import NoReturn
from collections.abc import Callable
from functools import wraps
from typing import Any, NoReturn, ParamSpec, TypeVar
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
@@ -21,14 +23,34 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from libs.login import login_required
from models import App, AppMode
from models.account import Account
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowDraftVariableListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
class WorkflowDraftVariableUpdatePayload(BaseModel):
name: str | None = Field(default=None, description="Variable name")
value: Any | None = Field(default=None, description="Variable value")
console_ns.schema_model(
WorkflowDraftVariableListQuery.__name__,
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkflowDraftVariableUpdatePayload.__name__,
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _convert_values_to_json_serializable_object(value: Segment):
@@ -57,20 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
@@ -139,8 +147,42 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
# Register models for flask_restx to avoid dict type issues in Swagger
workflow_draft_variable_without_value_model = console_ns.model(
"WorkflowDraftVariableWithoutValue", _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS
)
def _api_prerequisite(f):
workflow_draft_variable_model = console_ns.model("WorkflowDraftVariable", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
workflow_draft_env_variable_model = console_ns.model("WorkflowDraftEnvVariable", _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)
workflow_draft_env_variable_list_fields_copy = _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS.copy()
workflow_draft_env_variable_list_fields_copy["items"] = fields.List(fields.Nested(workflow_draft_env_variable_model))
workflow_draft_env_variable_list_model = console_ns.model(
"WorkflowDraftEnvVariableList", workflow_draft_env_variable_list_fields_copy
)
workflow_draft_variable_list_without_value_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS.copy()
workflow_draft_variable_list_without_value_fields_copy["items"] = fields.List(
fields.Nested(workflow_draft_variable_without_value_model), attribute=_get_items
)
workflow_draft_variable_list_without_value_model = console_ns.model(
"WorkflowDraftVariableListWithoutValue", workflow_draft_variable_list_without_value_fields_copy
)
workflow_draft_variable_list_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS.copy()
workflow_draft_variable_list_fields_copy["items"] = fields.List(
fields.Nested(workflow_draft_variable_model), attribute=_get_items
)
workflow_draft_variable_list_model = console_ns.model(
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
)
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite(f: Callable[P, R]):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@@ -154,11 +196,10 @@ def _api_prerequisite(f):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs):
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs):
return f(*args, **kwargs)
return wrapper
@@ -166,19 +207,21 @@ def _api_prerequisite(f):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource):
@api.doc("get_workflow_variables")
@api.doc(description="Get draft workflow variables")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
@api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.doc("get_workflow_variables")
@console_ns.doc(description="Get draft workflow variables")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
@console_ns.response(
200, "Workflow variables retrieved successfully", workflow_draft_variable_list_without_value_model
)
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, app_model: App):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# fetch draft workflow by app_model
workflow_service = WorkflowService()
@@ -199,9 +242,9 @@ class WorkflowVariableCollectionApi(Resource):
return workflow_vars
@api.doc("delete_workflow_variables")
@api.doc(description="Delete all draft workflow variables")
@api.response(204, "Workflow variables deleted successfully")
@console_ns.doc("delete_workflow_variables")
@console_ns.doc(description="Delete all draft workflow variables")
@console_ns.response(204, "Workflow variables deleted successfully")
@_api_prerequisite
def delete(self, app_model: App):
draft_var_srv = WorkflowDraftVariableService(
@@ -232,12 +275,12 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
class NodeVariableCollectionApi(Resource):
@api.doc("get_node_variables")
@api.doc(description="Get variables for a specific node")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@console_ns.doc("get_node_variables")
@console_ns.doc(description="Get variables for a specific node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
@@ -248,9 +291,9 @@ class NodeVariableCollectionApi(Resource):
return node_vars
@api.doc("delete_node_variables")
@api.doc(description="Delete all variables for a specific node")
@api.response(204, "Node variables deleted successfully")
@console_ns.doc("delete_node_variables")
@console_ns.doc(description="Delete all variables for a specific node")
@console_ns.response(204, "Node variables deleted successfully")
@_api_prerequisite
def delete(self, app_model: App, node_id: str):
validate_node_id(node_id)
@@ -265,13 +308,13 @@ class VariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"
@api.doc("get_variable")
@api.doc(description="Get a specific workflow variable")
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
@api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@api.response(404, "Variable not found")
@console_ns.doc("get_variable")
@console_ns.doc(description="Get a specific workflow variable")
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@marshal_with(workflow_draft_variable_model)
def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
@@ -283,21 +326,13 @@ class VariableApi(Resource):
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
@api.doc("update_variable")
@api.doc(description="Update a workflow variable")
@api.expect(
api.model(
"UpdateVariableRequest",
{
"name": fields.String(description="Variable name"),
"value": fields.Raw(description="Variable value"),
},
)
)
@api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@api.response(404, "Variable not found")
@console_ns.doc("update_variable")
@console_ns.doc(description="Update a workflow variable")
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@marshal_with(workflow_draft_variable_model)
def patch(self, app_model: App, variable_id: str):
# Request payload for file types:
#
@@ -320,15 +355,10 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
# Parse 'value' field as-is to maintain its original data structure
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
@@ -336,8 +366,8 @@ class VariableApi(Resource):
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
new_name = args_model.name
raw_value = args_model.value
if new_name is None and raw_value is None:
return variable
@@ -358,10 +388,10 @@ class VariableApi(Resource):
db.session.commit()
return variable
@api.doc("delete_variable")
@api.doc(description="Delete a workflow variable")
@api.response(204, "Variable deleted successfully")
@api.response(404, "Variable not found")
@console_ns.doc("delete_variable")
@console_ns.doc(description="Delete a workflow variable")
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
@@ -379,12 +409,12 @@ class VariableApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class VariableResetApi(Resource):
@api.doc("reset_variable")
@api.doc(description="Reset a workflow variable to its default value")
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
@api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@api.response(204, "Variable reset (no content)")
@api.response(404, "Variable not found")
@console_ns.doc("reset_variable")
@console_ns.doc(description="Reset a workflow variable to its default value")
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
@@ -408,7 +438,7 @@ class VariableResetApi(Resource):
if resetted is None:
return Response("", 204)
else:
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
return marshal(resetted, workflow_draft_variable_model)
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
@@ -427,13 +457,13 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables")
class ConversationVariableCollectionApi(Resource):
@api.doc("get_conversation_variables")
@api.doc(description="Get conversation variables for workflow")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@api.response(404, "Draft workflow not found")
@console_ns.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
@console_ns.response(404, "Draft workflow not found")
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App):
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
# so their IDs can be returned to the caller.
@@ -449,23 +479,23 @@ class ConversationVariableCollectionApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource):
@api.doc("get_system_variables")
@api.doc(description="Get system variables for workflow")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@console_ns.doc("get_system_variables")
@console_ns.doc(description="Get system variables for workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App):
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
class EnvironmentVariableCollectionApi(Resource):
@api.doc("get_environment_variables")
@api.doc(description="Get environment variables for workflow")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Environment variables retrieved successfully")
@api.response(404, "Draft workflow not found")
@console_ns.doc("get_environment_variables")
@console_ns.doc(description="Get environment variables for workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Environment variables retrieved successfully")
@console_ns.response(404, "Draft workflow not found")
@_api_prerequisite
def get(self, app_model: App):
"""

View File

@@ -1,90 +1,341 @@
from typing import cast
from typing import Literal, cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
from fields.workflow_run_fields import (
advanced_chat_workflow_run_for_list_fields,
advanced_chat_workflow_run_pagination_fields,
workflow_run_count_fields,
workflow_run_detail_fields,
workflow_run_for_list_fields,
workflow_run_node_execution_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import login_required
from models import Account, App, AppMode, EndUser
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
# Models that depend on simple_account_fields
workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy()
workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy)
advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy()
advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
advanced_chat_workflow_run_for_list_model = console_ns.model(
"AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy
)
workflow_run_detail_fields_copy = workflow_run_detail_fields.copy()
workflow_run_detail_fields_copy["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested(
simple_end_user_model, attribute="created_by_end_user", allow_null=True
)
workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy)
workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy()
workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested(
simple_end_user_model, attribute="created_by_end_user", allow_null=True
)
workflow_run_node_execution_model = console_ns.model(
"WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy
)
# Simple models without nested dependencies
workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields)
# Pagination models that depend on list models
advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy()
advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List(
fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data"
)
advanced_chat_workflow_run_pagination_model = console_ns.model(
"AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy
)
workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy()
workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data")
workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy)
workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy()
workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model))
workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowRunListQuery(BaseModel):
last_id: str | None = Field(default=None, description="Last run ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class WorkflowRunCountQuery(BaseModel):
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
@field_validator("time_range")
@classmethod
def validate_time_range(cls, value: str | None) -> str | None:
if value is None:
return value
return time_duration(value)
console_ns.schema_model(
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
WorkflowRunCountQuery.__name__,
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource):
@api.doc("get_advanced_chat_workflow_runs")
@api.doc(description="Get advanced chat workflow run list")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
@console_ns.doc("get_advanced_chat_workflow_runs")
@console_ns.doc(description="Get advanced chat workflow run list")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@console_ns.doc(
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
)
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(advanced_chat_workflow_run_pagination_fields)
@marshal_with(advanced_chat_workflow_run_pagination_model)
def get(self, app_model: App):
"""
Get advanced chat app workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs_count")
@console_ns.doc(description="Get advanced chat workflow runs count statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
)
@console_ns.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(workflow_run_count_model)
def get(self, app_model: App):
"""
Get advanced chat workflow runs count statistics
"""
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs")
class WorkflowRunListApi(Resource):
@api.doc("get_workflow_runs")
@api.doc(description="Get workflow run list")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
@console_ns.doc("get_workflow_runs")
@console_ns.doc(description="Get workflow run list")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@console_ns.doc(
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
)
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_pagination_fields)
@marshal_with(workflow_run_pagination_model)
def get(self, app_model: App):
"""
Get workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
class WorkflowRunCountApi(Resource):
@console_ns.doc("get_workflow_runs_count")
@console_ns.doc(description="Get workflow runs count statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.doc(
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
)
@console_ns.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_count_model)
def get(self, app_model: App):
"""
Get workflow runs count statistics
"""
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
class WorkflowRunDetailApi(Resource):
@api.doc("get_workflow_run_detail")
@api.doc(description="Get workflow run detail")
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields)
@api.response(404, "Workflow run not found")
@console_ns.doc("get_workflow_run_detail")
@console_ns.doc(description="Get workflow run detail")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_detail_fields)
@marshal_with(workflow_run_detail_model)
def get(self, app_model: App, run_id):
"""
Get workflow run detail
@@ -99,16 +350,16 @@ class WorkflowRunDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
class WorkflowRunNodeExecutionListApi(Resource):
@api.doc("get_workflow_run_node_executions")
@api.doc(description="Get workflow run node execution list")
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields)
@api.response(404, "Workflow run not found")
@console_ns.doc("get_workflow_run_node_executions")
@console_ns.doc(description="Get workflow run node execution list")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_list_fields)
@marshal_with(workflow_run_node_execution_list_model)
def get(self, app_model: App, run_id):
"""
Get workflow run node execution list

View File

@@ -1,311 +1,194 @@
from datetime import datetime
from decimal import Decimal
from flask import abort, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, reqparse
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from libs.datetime_utils import parse_time_range
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowStatisticQuery(BaseModel):
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
WorkflowStatisticQuery.__name__,
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
@api.doc("get_workflow_daily_runs_statistic")
@api.doc(description="Get workflow daily runs statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
@api.response(200, "Daily runs statistics retrieved successfully")
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@console_ns.doc("get_workflow_daily_runs_statistic")
@console_ns.doc(description="Get workflow daily runs statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.response(200, "Daily runs statistics retrieved successfully")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(id) AS runs
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs})
response_data = self._workflow_run_repo.get_daily_runs_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
class WorkflowDailyTerminalsStatistic(Resource):
@api.doc("get_workflow_daily_terminals_statistic")
@api.doc(description="Get workflow daily terminals statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
@api.response(200, "Daily terminals statistics retrieved successfully")
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@console_ns.doc("get_workflow_daily_terminals_statistic")
@console_ns.doc(description="Get workflow daily terminals statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
class WorkflowDailyTokenCostStatistic(Resource):
@api.doc("get_workflow_daily_token_cost_statistic")
@api.doc(description="Get workflow daily token cost statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
@api.response(200, "Daily token cost statistics retrieved successfully")
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@console_ns.doc("get_workflow_daily_token_cost_statistic")
@console_ns.doc(description="Get workflow daily token cost statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) AS token_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
"date": str(i.date),
"token_count": i.token_count,
}
)
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
class WorkflowAverageAppInteractionStatistic(Resource):
@api.doc("get_workflow_average_app_interaction_statistic")
@api.doc(description="Get workflow average app interaction statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
@api.response(200, "Average app interaction statistics retrieved successfully")
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@console_ns.doc("get_workflow_average_app_interaction_statistic")
@console_ns.doc(description="Get workflow average app interaction statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
sql_query = """SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM
(
SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM
workflow_runs c
WHERE
c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY
date, c.created_by
) sub
GROUP BY
sub.date"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
arg_dict["start"] = start_datetime_utc
else:
sql_query = sql_query.replace("{{start}}", "")
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
arg_dict["end"] = end_datetime_utc
else:
sql_query = sql_query.replace("{{end}}", "")
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
)
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})

View File

@@ -0,0 +1,157 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
from models.trigger import AppTrigger, WorkflowWebhookTrigger
from .. import console_ns
from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class Parser(BaseModel):
node_id: str
class ParserEnable(BaseModel):
trigger_id: str
enable_trigger: bool
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
node_id = args.node_id
with Session(db.engine) as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.where(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.first()
)
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
return webhook_trigger
@console_ns.route("/apps/<uuid:app_id>/triggers")
class AppTriggersApi(Resource):
"""App Triggers list API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_fields)
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with Session(db.engine) as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
select(AppTrigger)
.where(
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
)
.scalars()
.all()
)
# Add computed icon field for each trigger
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
for trigger in triggers:
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__])
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)
assert current_user.current_tenant_id is not None
trigger_id = args.trigger_id
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
select(AppTrigger).where(
AppTrigger.id == trigger_id,
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
).scalar_one_or_none()
if not trigger:
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return trigger

View File

@@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from libs.login import current_account_with_tenant
from models import App, AppMode
from models.account import Account
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
assert isinstance(current_user, Account)
_, current_tenant_id = current_account_with_tenant()
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@@ -1,36 +1,57 @@
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus
from libs.helper import EmailStr, extract_remote_ip, timezone
from models import AccountStatus
from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser()
active_check_parser.add_argument(
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
)
active_check_parser.add_argument(
"email", type=email, required=False, nullable=True, location="args", help="Email address"
)
active_check_parser.add_argument(
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
class ActivatePayload(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
name: str = Field(..., max_length=30)
interface_language: str = Field(...)
timezone: str = Field(...)
@field_validator("interface_language")
@classmethod
def validate_lang(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_tz(cls, value: str) -> str:
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/activate/check")
class ActivateCheckApi(Resource):
@api.doc("check_activation_token")
@api.doc(description="Check if activation token is valid")
@api.expect(active_check_parser)
@api.response(
@console_ns.doc("check_activation_token")
@console_ns.doc(description="Check if activation token is valid")
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
@console_ns.response(
200,
"Success",
api.model(
console_ns.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
@@ -39,11 +60,11 @@ class ActivateCheckApi(Resource):
),
)
def get(self):
args = active_check_parser.parse_args()
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args["workspace_id"]
reg_email = args["email"]
token = args["token"]
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
@@ -60,26 +81,15 @@ class ActivateCheckApi(Resource):
return {"is_valid": False}
active_parser = reqparse.RequestParser()
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
active_parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
@console_ns.route("/activate")
class ActivateApi(Resource):
@api.doc("activate_account")
@api.doc(description="Activate account with invitation token")
@api.expect(active_parser)
@api.response(
@console_ns.doc("activate_account")
@console_ns.doc(description="Activate account with invitation token")
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
@console_ns.response(
200,
"Account activated successfully",
api.model(
console_ns.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
@@ -87,23 +97,23 @@ class ActivateApi(Resource):
},
),
)
@api.response(400, "Already activated or invalid token")
@console_ns.response(400, "Already activated or invalid token")
def post(self):
args = active_parser.parse_args()
args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
account = invitation["account"]
account.name = args["name"]
account.name = args.name
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@@ -1,21 +1,36 @@
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import api
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required
from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ApiKeyAuthBindingPayload(BaseModel):
category: str = Field(...)
provider: str = Field(...)
credentials: dict = Field(...)
console_ns.schema_model(
ApiKeyAuthBindingPayload.__name__,
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/api-key-auth/data-source")
class ApiKeyAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings:
return {
"sources": [
@@ -33,41 +48,36 @@ class ApiKeyAuthDataSource(Resource):
return {"sources": []}
@console_ns.route("/api-key-auth/data-source/binding")
class ApiKeyAuthDataSourceBinding(Resource):
@setup_required
@login_required
@account_initialization_required
@is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self):
# The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
_, current_tenant_id = current_account_with_tenant()
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
data = payload.model_dump()
ApiKeyAuthService.validate_api_key_auth_args(data)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200
@console_ns.route("/api-key-auth/data-source/<uuid:binding_id>")
class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required
@login_required
@account_initialization_required
@is_admin_or_owner_required
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")

View File

@@ -2,16 +2,14 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api, console_ns
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
logger = logging.getLogger(__name__)
@@ -30,23 +28,22 @@ def get_oauth_providers():
@console_ns.route("/oauth/data-source/<string:provider>")
class OAuthDataSource(Resource):
@api.doc("oauth_data_source")
@api.doc(description="Get OAuth authorization URL for data source provider")
@api.doc(params={"provider": "Data source provider name (notion)"})
@api.response(
@console_ns.doc("oauth_data_source")
@console_ns.doc(description="Get OAuth authorization URL for data source provider")
@console_ns.doc(params={"provider": "Data source provider name (notion)"})
@console_ns.response(
200,
"Authorization URL or internal setup success",
api.model(
console_ns.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
)
@api.response(400, "Invalid provider")
@api.response(403, "Admin privileges required")
@console_ns.response(400, "Invalid provider")
@console_ns.response(403, "Admin privileges required")
@is_admin_or_owner_required
def get(self, provider: str):
# The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
@@ -65,17 +62,17 @@ class OAuthDataSource(Resource):
@console_ns.route("/oauth/data-source/callback/<string:provider>")
class OAuthDataSourceCallback(Resource):
@api.doc("oauth_data_source_callback")
@api.doc(description="Handle OAuth callback from data source provider")
@api.doc(
@console_ns.doc("oauth_data_source_callback")
@console_ns.doc(description="Handle OAuth callback from data source provider")
@console_ns.doc(
params={
"provider": "Data source provider name (notion)",
"code": "Authorization code from OAuth provider",
"error": "Error message from OAuth provider",
}
)
@api.response(302, "Redirect to console with result")
@api.response(400, "Invalid provider")
@console_ns.response(302, "Redirect to console with result")
@console_ns.response(400, "Invalid provider")
def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@@ -96,17 +93,17 @@ class OAuthDataSourceCallback(Resource):
@console_ns.route("/oauth/data-source/binding/<string:provider>")
class OAuthDataSourceBinding(Resource):
@api.doc("oauth_data_source_binding")
@api.doc(description="Bind OAuth data source with authorization code")
@api.doc(
@console_ns.doc("oauth_data_source_binding")
@console_ns.doc(description="Bind OAuth data source with authorization code")
@console_ns.doc(
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
)
@api.response(
@console_ns.response(
200,
"Data source binding success",
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid provider or code")
@console_ns.response(400, "Invalid provider or code")
def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@@ -130,15 +127,15 @@ class OAuthDataSourceBinding(Resource):
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
class OAuthDataSourceSync(Resource):
@api.doc("oauth_data_source_sync")
@api.doc(description="Sync data from OAuth data source")
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
@api.response(
@console_ns.doc("oauth_data_source_sync")
@console_ns.doc(description="Sync data from OAuth data source")
@console_ns.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
@console_ns.response(
200,
"Data source sync success",
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid provider or sync failed")
@console_ns.response(400, "Invalid provider or sync failed")
@setup_required
@login_required
@account_initialization_required

View File

@@ -1,11 +1,12 @@
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import languages
from controllers.console import api
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailCodeError,
@@ -14,101 +15,122 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from ..error import AccountInFreezeError, EmailSendIpLimitError
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EmailRegisterSendPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
language: str | None = Field(default=None, description="Language code")
class EmailRegisterValidityPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class EmailRegisterResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
language = "en-US"
if args["language"] in languages:
language = args["language"]
if args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token}
@console_ns.route("/email-register/validity")
class EmailRegisterCheckApi(Resource):
@setup_required
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args["email"]
user_email = args.email
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
token_data = AccountService.get_email_register_data(args["token"])
token_data = AccountService.get_email_register_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_email_register_token(args["token"])
AccountService.revoke_email_register_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_email_register_token(
user_email, code=args["code"], additional_data={"phase": "register"}
user_email, code=args.code, additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(args["email"])
AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/email-register")
class EmailRegisterResetApi(Resource):
@setup_required
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
if args.new_password != args.password_confirm:
raise PasswordMismatchError()
# Validate token and get register data
register_data = AccountService.get_email_register_data(args["token"])
register_data = AccountService.get_email_register_data(args.token)
if not register_data:
raise InvalidTokenError()
# Must use token in reset phase
@@ -116,7 +138,7 @@ class EmailRegisterResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_email_register_token(args["token"])
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
@@ -126,7 +148,7 @@ class EmailRegisterResetApi(Resource):
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(email, args["password_confirm"])
account = self._create_new_account(email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
@@ -148,8 +170,3 @@ class EmailRegisterResetApi(Resource):
raise AccountInFreezeError()
return account
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
api.add_resource(EmailRegisterResetApi, "/email-register")

View File

@@ -2,11 +2,12 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import api, console_ns
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
@@ -18,30 +19,50 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@api.doc("send_forgot_password_email")
@api.doc(description="Send password reset email")
@api.expect(
api.model(
"ForgotPasswordEmailRequest",
{
"email": fields.String(required=True, description="Email address"),
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
},
)
)
@api.response(
@console_ns.doc("send_forgot_password_email")
@console_ns.doc(description="Send password reset email")
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
@console_ns.response(
200,
"Email sent successfully",
api.model(
console_ns.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
@@ -50,30 +71,27 @@ class ForgotPasswordSendEmailApi(Resource):
},
),
)
@api.response(400, "Invalid email or rate limit exceeded")
@console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = AccountService.send_reset_password_email(
account=account,
email=args["email"],
email=args.email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@@ -83,22 +101,13 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource):
@api.doc("check_forgot_password_code")
@api.doc(description="Verify password reset code")
@api.expect(
api.model(
"ForgotPasswordCheckRequest",
{
"email": fields.String(required=True, description="Email address"),
"code": fields.String(required=True, description="Verification code"),
"token": fields.String(required=True, description="Reset token"),
},
)
)
@api.response(
@console_ns.doc("check_forgot_password_code")
@console_ns.doc(description="Verify password reset code")
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
@console_ns.response(
200,
"Code verified successfully",
api.model(
console_ns.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
@@ -107,80 +116,63 @@ class ForgotPasswordCheckApi(Resource):
},
),
)
@api.response(400, "Invalid code or token")
@console_ns.response(400, "Invalid code or token")
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args["email"]
user_email = args.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
token_data = AccountService.get_reset_password_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
user_email, code=args.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource):
@api.doc("reset_password")
@api.doc(description="Reset password with verification token")
@api.expect(
api.model(
"ForgotPasswordResetRequest",
{
"token": fields.String(required=True, description="Verification token"),
"new_password": fields.String(required=True, description="New password"),
"password_confirm": fields.String(required=True, description="Password confirmation"),
},
)
)
@api.response(
@console_ns.doc("reset_password")
@console_ns.doc(description="Reset password with verification token")
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
@console_ns.response(
200,
"Password reset successfully",
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid token or password mismatch")
@console_ns.response(400, "Invalid token or password mismatch")
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
if args.new_password != args.password_confirm:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
reset_data = AccountService.get_reset_password_data(args.token)
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
@@ -188,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(args.token)
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt)
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
@@ -221,8 +213,3 @@ class ForgotPasswordResetApi(Resource):
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@@ -1,13 +1,12 @@
from typing import cast
import flask_login
from flask import request
from flask_restx import Resource, reqparse
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
import services
from configs import dify_config
from constants.languages import languages
from controllers.console import api
from constants.languages import get_valid_language
from controllers.console import console_ns
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
@@ -23,55 +22,98 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
email_password_login_enabled,
setup_required,
)
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from models.account import Account
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_refresh_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
class EmailPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class EmailCodeLoginPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
language: str | None = Field(default=None)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(LoginPayload)
reg(EmailPayload)
reg(EmailCodeLoginPayload)
@console_ns.route("/login")
class LoginApi(Resource):
"""Resource for user login."""
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=str, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
args = parser.parse_args()
args = LoginPayload.model_validate(console_ns.payload)
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invitation = args["invite_token"]
# TODO: why invitation is re-assigned with different type?
invitation = args.invite_token # type: ignore
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
try:
if invitation:
data = invitation.get("data", {})
data = invitation.get("data", {}) # type: ignore
invitee_email = data.get("email") if data else None
if invitee_email != args["email"]:
if invitee_email != args.email:
raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
account = AccountService.authenticate(args["email"], args["password"])
account = AccountService.authenticate(args.email, args.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@@ -87,41 +129,58 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
@console_ns.route("/logout")
class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
def post(self):
current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
AccountService.logout(account=account)
flask_login.logout_user()
return {"result": "success"}
response = make_response({"result": "success"})
else:
AccountService.logout(account=account)
flask_login.logout_user()
response = make_response({"result": "success"})
# Clear cookies on logout
clear_access_token_from_cookie(response)
clear_refresh_token_from_cookie(response)
clear_csrf_token_from_cookie(response)
return response
@console_ns.route("/reset-password")
class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
args = EmailPayload.model_validate(console_ns.payload)
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=args["email"],
email=args.email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -130,30 +189,29 @@ class ResetPasswordSendEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
args = EmailPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
token = AccountService.send_email_code_login_email(email=args.email, language=language)
else:
raise AccountNotFound()
else:
@@ -162,28 +220,28 @@ class EmailCodeLoginSendEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
@decrypt_code_field
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
user_email = args["email"]
user_email = args.email
language = args.language
token_data = AccountService.get_email_code_login_data(args["token"])
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
if token_data["email"] != args.email:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
if token_data["code"] != args.code:
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"])
AccountService.revoke_email_code_login_token(args.token)
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError:
@@ -205,7 +263,9 @@ class EmailCodeLoginApi(Resource):
if account is None:
try:
account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0]
email=user_email,
name=user_email,
interface_language=get_valid_language(language),
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
@@ -214,26 +274,37 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
# Set HTTP-only secure cookies for tokens
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
return response
@console_ns.route("/refresh-token")
class RefreshTokenApi(Resource):
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("refresh_token", type=str, required=True, location="json")
args = parser.parse_args()
# Get refresh token from cookie instead of request body
refresh_token = extract_refresh_token(request)
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401
try:
new_token_pair = AccountService.refresh_token(args["refresh_token"])
return {"result": "success", "data": new_token_pair.model_dump()}
new_token_pair = AccountService.refresh_token(refresh_token)
# Create response with new cookies
response = make_response({"result": "success"})
# Update cookies with new tokens
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
set_access_token_to_cookie(request, response, new_token_pair.access_token)
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
return response
except Exception as e:
return {"result": "fail", "data": str(e)}, 401
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
api.add_resource(RefreshTokenApi, "/refresh-token")
return {"result": "fail", "message": str(e)}, 401

View File

@@ -14,15 +14,19 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from libs.token import (
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService
from .. import api, console_ns
from .. import console_ns
logger = logging.getLogger(__name__)
@@ -52,11 +56,13 @@ def get_oauth_providers():
@console_ns.route("/oauth/login/<provider>")
class OAuthLogin(Resource):
@api.doc("oauth_login")
@api.doc(description="Initiate OAuth login process")
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
@api.response(302, "Redirect to OAuth authorization URL")
@api.response(400, "Invalid provider")
@console_ns.doc("oauth_login")
@console_ns.doc(description="Initiate OAuth login process")
@console_ns.doc(
params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}
)
@console_ns.response(302, "Redirect to OAuth authorization URL")
@console_ns.response(400, "Invalid provider")
def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers()
@@ -71,17 +77,17 @@ class OAuthLogin(Resource):
@console_ns.route("/oauth/authorize/<provider>")
class OAuthCallback(Resource):
@api.doc("oauth_callback")
@api.doc(description="Handle OAuth callback and complete login process")
@api.doc(
@console_ns.doc("oauth_callback")
@console_ns.doc(description="Handle OAuth callback and complete login process")
@console_ns.doc(
params={
"provider": "OAuth provider name (github/google)",
"code": "Authorization code from OAuth provider",
"state": "Optional state parameter (used for invite token)",
}
)
@api.response(302, "Redirect to console with access token")
@api.response(400, "OAuth process failed")
@console_ns.response(302, "Redirect to console with access token")
@console_ns.response(400, "OAuth process failed")
def get(self, provider: str):
OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@@ -130,11 +136,11 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
# Check account status
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()
@@ -153,9 +159,12 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:

Some files were not shown because too many files have changed in this diff Show More