| import React from 'react' |
| import { Drawer, Title, Text, Stack, Box, Progress, Table, ScrollArea } from '@mantine/core' |
| import { useXRD } from '../context/XRDContext' |
|
|
| const LogitDrawer = () => { |
| const { isLogitDrawerOpen, setIsLogitDrawerOpen, modelResults } = useXRD() |
| |
| if (!modelResults?.predictions?.phase_predictions) { |
| return null |
| } |
| |
| |
| const classificationPredictions = modelResults.predictions.phase_predictions |
| .filter(p => !p.is_lattice && p.confidence) |
| |
| const getColorForProbability = (prob) => { |
| if (prob >= 0.8) return 'green' |
| if (prob >= 0.5) return 'yellow' |
| return 'orange' |
| } |
| |
| return ( |
| <Drawer |
| opened={isLogitDrawerOpen} |
| onClose={() => setIsLogitDrawerOpen(false)} |
| position="right" |
| size="xl" |
| title={ |
| <Title order={3}>Class Logit Distributions</Title> |
| } |
| padding="xl" |
| > |
| <Stack gap="xl"> |
| <Text size="sm" c="dimmed"> |
| Detailed logit scores for each classification task. Note: Logits have been normalized with softmax. |
| </Text> |
| |
| {classificationPredictions.map((pred, idx) => ( |
| <Box key={idx}> |
| <Title order={4} mb="md">{pred.phase}</Title> |
| <Text size="sm" mb="md" c="dimmed"> |
| Top Prediction: <Text span fw={700} c="violet">{pred.predicted_class}</Text> ({(pred.confidence * 100).toFixed(2)}%) |
| </Text> |
| |
| {/* Display all probabilities for Crystal System */} |
| {pred.all_probabilities && ( |
| <ScrollArea> |
| <Table highlightOnHover> |
| <Table.Thead> |
| <Table.Tr> |
| <Table.Th>Rank</Table.Th> |
| <Table.Th>Class</Table.Th> |
| <Table.Th>Logits</Table.Th> |
| <Table.Th>Distribution</Table.Th> |
| </Table.Tr> |
| </Table.Thead> |
| <Table.Tbody> |
| {pred.all_probabilities.map((item, i) => ( |
| <Table.Tr key={i} style={{ |
| backgroundColor: i === 0 ? '#f3f0ff' : 'transparent', |
| fontWeight: i === 0 ? 600 : 400 |
| }}> |
| <Table.Td>{i + 1}</Table.Td> |
| <Table.Td style={{ fontFamily: 'monospace' }}>{item.class_name}</Table.Td> |
| <Table.Td style={{ fontFamily: 'monospace' }}> |
| {(item.probability * 100).toFixed(2)}% |
| </Table.Td> |
| <Table.Td style={{ width: '40%' }}> |
| <Progress |
| value={item.probability * 100} |
| color={getColorForProbability(item.probability)} |
| size="lg" |
| /> |
| </Table.Td> |
| </Table.Tr> |
| ))} |
| </Table.Tbody> |
| </Table> |
| </ScrollArea> |
| )} |
| |
| {/* Display top 10 probabilities for Space Group */} |
| {pred.top_probabilities && ( |
| <ScrollArea> |
| <Table highlightOnHover> |
| <Table.Thead> |
| <Table.Tr> |
| <Table.Th>Rank</Table.Th> |
| <Table.Th>Space Group</Table.Th> |
| <Table.Th>Symbol</Table.Th> |
| <Table.Th>Logits</Table.Th> |
| <Table.Th>Distribution</Table.Th> |
| </Table.Tr> |
| </Table.Thead> |
| <Table.Tbody> |
| {pred.top_probabilities.map((item, i) => ( |
| <Table.Tr key={i} style={{ |
| backgroundColor: i === 0 ? '#f3f0ff' : 'transparent', |
| fontWeight: i === 0 ? 600 : 400 |
| }}> |
| <Table.Td>{i + 1}</Table.Td> |
| <Table.Td style={{ fontFamily: 'monospace' }}>#{item.space_group_number}</Table.Td> |
| <Table.Td style={{ fontFamily: 'monospace' }}>{item.space_group_symbol}</Table.Td> |
| <Table.Td style={{ fontFamily: 'monospace' }}> |
| {(item.probability * 100).toFixed(2)}% |
| </Table.Td> |
| <Table.Td style={{ width: '35%' }}> |
| <Progress |
| value={item.probability * 100} |
| color={getColorForProbability(item.probability)} |
| size="lg" |
| /> |
| </Table.Td> |
| </Table.Tr> |
| ))} |
| </Table.Tbody> |
| </Table> |
| </ScrollArea> |
| )} |
| </Box> |
| ))} |
| |
| <Box mt="md" p="md" style={{ backgroundColor: '#f8f9fa', borderRadius: '8px' }}> |
| <Text size="xs" c="dimmed"> |
| <strong>Note:</strong> The model outputs raw logit scores for each possible class. |
| These scores are normalized using softmax to show relative confidence. For Crystal System, all 7 |
| classes are shown. For Space Group, the top 10 out of 230 possible groups are displayed. |
| </Text> |
| </Box> |
| </Stack> |
| </Drawer> |
| ) |
| } |
|
|
| export default LogitDrawer |
|
|